| //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines the types and operation details for the LLVM IR dialect in |
| // MLIR, and the LLVM IR dialect. It also registers the dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/FunctionImplementation.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/StandardTypes.h" |
| |
| #include "llvm/AsmParser/Parser.h" |
| #include "llvm/IR/Attributes.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/Type.h" |
| #include "llvm/Support/Mutex.h" |
| #include "llvm/Support/SourceMgr.h" |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| |
| #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::CmpOp. |
| //===----------------------------------------------------------------------===// |
| static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { |
| p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) |
| << "\" " << op.getOperand(0) << ", " << op.getOperand(1); |
| p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); |
| p << " : " << op.lhs().getType(); |
| } |
| |
| static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { |
| p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) |
| << "\" " << op.getOperand(0) << ", " << op.getOperand(1); |
| p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); |
| p << " : " << op.lhs().getType(); |
| } |
| |
| // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use |
| // attribute-dict? `:` type |
| // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use |
| // attribute-dict? `:` type |
| template <typename CmpPredicateType> |
| static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { |
| Builder &builder = parser.getBuilder(); |
| |
| StringAttr predicateAttr; |
| OpAsmParser::OperandType lhs, rhs; |
| Type type; |
| llvm::SMLoc predicateLoc, trailingTypeLoc; |
| if (parser.getCurrentLocation(&predicateLoc) || |
| parser.parseAttribute(predicateAttr, "predicate", result.attributes) || |
| parser.parseOperand(lhs) || parser.parseComma() || |
| parser.parseOperand(rhs) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || |
| parser.resolveOperand(lhs, type, result.operands) || |
| parser.resolveOperand(rhs, type, result.operands)) |
| return failure(); |
| |
| // Replace the string attribute `predicate` with an integer attribute. |
| int64_t predicateValue = 0; |
| if (std::is_same<CmpPredicateType, ICmpPredicate>()) { |
| Optional<ICmpPredicate> predicate = |
| symbolizeICmpPredicate(predicateAttr.getValue()); |
| if (!predicate) |
| return parser.emitError(predicateLoc) |
| << "'" << predicateAttr.getValue() |
| << "' is an incorrect value of the 'predicate' attribute"; |
| predicateValue = static_cast<int64_t>(predicate.getValue()); |
| } else { |
| Optional<FCmpPredicate> predicate = |
| symbolizeFCmpPredicate(predicateAttr.getValue()); |
| if (!predicate) |
| return parser.emitError(predicateLoc) |
| << "'" << predicateAttr.getValue() |
| << "' is an incorrect value of the 'predicate' attribute"; |
| predicateValue = static_cast<int64_t>(predicate.getValue()); |
| } |
| |
| result.attributes[0].second = |
| parser.getBuilder().getI64IntegerAttr(predicateValue); |
| |
| // The result type is either i1 or a vector type <? x i1> if the inputs are |
| // vectors. |
| auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>(); |
| auto resultType = LLVMType::getInt1Ty(dialect); |
| auto argType = type.dyn_cast<LLVM::LLVMType>(); |
| if (!argType) |
| return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); |
| if (argType.getUnderlyingType()->isVectorTy()) |
| resultType = LLVMType::getVectorTy( |
| resultType, argType.getUnderlyingType()->getVectorNumElements()); |
| |
| result.addTypes({resultType}); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::AllocaOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { |
| auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy(); |
| |
| auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()}, |
| op.getContext()); |
| |
| p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; |
| if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) |
| p.printOptionalAttrDict(op.getAttrs()); |
| else |
| p.printOptionalAttrDict(op.getAttrs(), {"alignment"}); |
| p << " : " << funcTy; |
| } |
| |
| // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? |
| // `:` type `,` type |
| static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { |
| OpAsmParser::OperandType arraySize; |
| Type type, elemType; |
| llvm::SMLoc trailingTypeLoc; |
| if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || |
| parser.parseType(elemType) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) |
| return failure(); |
| |
| // Extract the result type from the trailing function type. |
| auto funcType = type.dyn_cast<FunctionType>(); |
| if (!funcType || funcType.getNumInputs() != 1 || |
| funcType.getNumResults() != 1) |
| return parser.emitError( |
| trailingTypeLoc, |
| "expected trailing function type with one argument and one result"); |
| |
| if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) |
| return failure(); |
| |
| result.addTypes({funcType.getResult(0)}); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::LoadOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { |
| p << op.getOperationName() << ' ' << op.addr(); |
| p.printOptionalAttrDict(op.getAttrs()); |
| p << " : " << op.addr().getType(); |
| } |
| |
| // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return |
| // the resulting type wrapped in MLIR, or nullptr on error. |
| static Type getLoadStoreElementType(OpAsmParser &parser, Type type, |
| llvm::SMLoc trailingTypeLoc) { |
| auto llvmTy = type.dyn_cast<LLVM::LLVMType>(); |
| if (!llvmTy) |
| return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"), |
| nullptr; |
| if (!llvmTy.getUnderlyingType()->isPointerTy()) |
| return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), |
| nullptr; |
| return llvmTy.getPointerElementTy(); |
| } |
| |
| // <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type |
| static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { |
| OpAsmParser::OperandType addr; |
| Type type; |
| llvm::SMLoc trailingTypeLoc; |
| |
| if (parser.parseOperand(addr) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || |
| parser.resolveOperand(addr, type, result.operands)) |
| return failure(); |
| |
| Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); |
| |
| result.addTypes(elemTy); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::StoreOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { |
| p << op.getOperationName() << ' ' << op.value() << ", " << op.addr(); |
| p.printOptionalAttrDict(op.getAttrs()); |
| p << " : " << op.addr().getType(); |
| } |
| |
| // <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type |
| static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { |
| OpAsmParser::OperandType addr, value; |
| Type type; |
| llvm::SMLoc trailingTypeLoc; |
| |
| if (parser.parseOperand(value) || parser.parseComma() || |
| parser.parseOperand(addr) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) |
| return failure(); |
| |
| Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); |
| if (!elemTy) |
| return failure(); |
| |
| if (parser.resolveOperand(value, elemTy, result.operands) || |
| parser.resolveOperand(addr, type, result.operands)) |
| return failure(); |
| |
| return success(); |
| } |
| |
| ///===----------------------------------------------------------------------===// |
| /// Verifying/Printing/Parsing for LLVM::InvokeOp. |
| ///===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(InvokeOp op) { |
| if (op.getNumResults() > 1) |
| return op.emitOpError("must have 0 or 1 result"); |
| |
| Block *unwindDest = op.unwindDest(); |
| if (unwindDest->empty()) |
| return op.emitError( |
| "must have at least one operation in unwind destination"); |
| |
| // In unwind destination, first operation must be LandingpadOp |
| if (!isa<LandingpadOp>(unwindDest->front())) |
| return op.emitError("first operation in unwind destination should be a " |
| "llvm.landingpad operation"); |
| |
| return success(); |
| } |
| |
| static void printInvokeOp(OpAsmPrinter &p, InvokeOp &op) { |
| auto callee = op.callee(); |
| bool isDirect = callee.hasValue(); |
| |
| p << op.getOperationName() << ' '; |
| |
| // Either function name or pointer |
| if (isDirect) |
| p.printSymbolName(callee.getValue()); |
| else |
| p << op.getOperand(0); |
| |
| p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; |
| p << " to "; |
| p.printSuccessorAndUseList(op.getOperation(), 0); |
| p << " unwind "; |
| p.printSuccessorAndUseList(op.getOperation(), 1); |
| |
| p.printOptionalAttrDict(op.getAttrs(), {"callee"}); |
| |
| SmallVector<Type, 8> argTypes( |
| llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); |
| |
| p << " : " |
| << FunctionType::get(argTypes, op.getResultTypes(), op.getContext()); |
| } |
| |
| /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` |
| /// `to` bb-id (`[` ssa-use-and-type-list `]`)? |
| /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? |
| /// attribute-dict? `:` function-type |
| static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 8> operands; |
| FunctionType funcType; |
| SymbolRefAttr funcAttr; |
| llvm::SMLoc trailingTypeLoc; |
| Block *normalDest, *unwindDest; |
| SmallVector<Value, 4> normalOperands, unwindOperands; |
| |
| // Parse an operand list that will, in practice, contain 0 or 1 operand. In |
| // case of an indirect call, there will be 1 operand before `(`. In case of a |
| // direct call, there will be no operands and the parser will stop at the |
| // function identifier without complaining. |
| if (parser.parseOperandList(operands)) |
| return failure(); |
| bool isDirect = operands.empty(); |
| |
| // Optionally parse a function identifier. |
| if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) |
| return failure(); |
| |
| if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || |
| parser.parseKeyword("to") || |
| parser.parseSuccessorAndUseList(normalDest, normalOperands) || |
| parser.parseKeyword("unwind") || |
| parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) |
| return failure(); |
| |
| if (isDirect) { |
| // Make sure types match. |
| if (parser.resolveOperands(operands, funcType.getInputs(), |
| parser.getNameLoc(), result.operands)) |
| return failure(); |
| result.addTypes(funcType.getResults()); |
| } else { |
| // Construct the LLVM IR Dialect function type that the first operand |
| // should match. |
| if (funcType.getNumResults() > 1) |
| return parser.emitError(trailingTypeLoc, |
| "expected function with 0 or 1 result"); |
| |
| Builder &builder = parser.getBuilder(); |
| auto *llvmDialect = |
| builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); |
| LLVM::LLVMType llvmResultType; |
| if (funcType.getNumResults() == 0) { |
| llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); |
| } else { |
| llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>(); |
| if (!llvmResultType) |
| return parser.emitError(trailingTypeLoc, |
| "expected result to have LLVM type"); |
| } |
| |
| SmallVector<LLVM::LLVMType, 8> argTypes; |
| argTypes.reserve(funcType.getNumInputs()); |
| for (Type ty : funcType.getInputs()) { |
| if (auto argType = ty.dyn_cast<LLVM::LLVMType>()) |
| argTypes.push_back(argType); |
| else |
| return parser.emitError(trailingTypeLoc, |
| "expected LLVM types as inputs"); |
| } |
| |
| auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, |
| /*isVarArg=*/false); |
| auto wrappedFuncType = llvmFuncType.getPointerTo(); |
| |
| auto funcArguments = llvm::makeArrayRef(operands).drop_front(); |
| |
| // Make sure that the first operand (indirect callee) matches the wrapped |
| // LLVM IR function type, and that the types of the other call operands |
| // match the types of the function arguments. |
| if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || |
| parser.resolveOperands(funcArguments, funcType.getInputs(), |
| parser.getNameLoc(), result.operands)) |
| return failure(); |
| |
| result.addTypes(llvmResultType); |
| } |
| result.addSuccessor(normalDest, normalOperands); |
| result.addSuccessor(unwindDest, unwindOperands); |
| return success(); |
| } |
| |
| ///===----------------------------------------------------------------------===// |
| /// Verifying/Printing/Parsing for LLVM::LandingpadOp. |
| ///===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(LandingpadOp op) { |
| Value value; |
| |
| if (!op.cleanup() && op.getOperands().empty()) |
| return op.emitError("landingpad instruction expects at least one clause or " |
| "cleanup attribute"); |
| |
| for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { |
| value = op.getOperand(idx); |
| bool isFilter = value.getType().cast<LLVMType>().isArrayTy(); |
| if (isFilter) { |
| // FIXME: Verify filter clauses when arrays are appropriately handled |
| } else { |
| // catch - global addresses only. |
| // Bitcast ops should have global addresses as their args. |
| if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) { |
| if (auto addrOp = |
| dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp())) |
| continue; |
| return op.emitError("constant clauses expected") |
| .attachNote(bcOp.getLoc()) |
| << "global addresses expected as operand to " |
| "bitcast used in clauses for landingpad"; |
| } |
| // NullOp and AddressOfOp allowed |
| if (dyn_cast_or_null<NullOp>(value.getDefiningOp())) |
| continue; |
| if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp())) |
| continue; |
| return op.emitError("clause #") |
| << idx << " is not a known constant - null, addressof, bitcast"; |
| } |
| } |
| return success(); |
| } |
| |
| static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { |
| p << op.getOperationName() << (op.cleanup() ? " cleanup " : " "); |
| |
| // Clauses |
| for (auto value : op.getOperands()) { |
| // Similar to llvm - if clause is an array type then it is filter |
| // clause else catch clause |
| bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy(); |
| p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " |
| << value.getType() << ") "; |
| } |
| |
| p.printOptionalAttrDict(op.getAttrs(), {"cleanup"}); |
| |
| p << ": " << op.getType(); |
| } |
| |
| /// <operation> ::= `llvm.landingpad` `cleanup`? |
| /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? |
| static ParseResult parseLandingpadOp(OpAsmParser &parser, |
| OperationState &result) { |
| // Check for cleanup |
| if (succeeded(parser.parseOptionalKeyword("cleanup"))) |
| result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); |
| |
| // Parse clauses with types |
| while (succeeded(parser.parseOptionalLParen()) && |
| (succeeded(parser.parseOptionalKeyword("filter")) || |
| succeeded(parser.parseOptionalKeyword("catch")))) { |
| OpAsmParser::OperandType operand; |
| Type ty; |
| if (parser.parseOperand(operand) || parser.parseColon() || |
| parser.parseType(ty) || |
| parser.resolveOperand(operand, ty, result.operands) || |
| parser.parseRParen()) |
| return failure(); |
| } |
| |
| Type type; |
| if (parser.parseColon() || parser.parseType(type)) |
| return failure(); |
| |
| result.addTypes(type); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::CallOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printCallOp(OpAsmPrinter &p, CallOp &op) { |
| auto callee = op.callee(); |
| bool isDirect = callee.hasValue(); |
| |
| // Print the direct callee if present as a function attribute, or an indirect |
| // callee (first operand) otherwise. |
| p << op.getOperationName() << ' '; |
| if (isDirect) |
| p.printSymbolName(callee.getValue()); |
| else |
| p << op.getOperand(0); |
| |
| p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; |
| p.printOptionalAttrDict(op.getAttrs(), {"callee"}); |
| |
| // Reconstruct the function MLIR function type from operand and result types. |
| SmallVector<Type, 8> argTypes( |
| llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); |
| |
| p << " : " |
| << FunctionType::get(argTypes, op.getResultTypes(), op.getContext()); |
| } |
| |
| // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` |
| // attribute-dict? `:` function-type |
| static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 8> operands; |
| Type type; |
| SymbolRefAttr funcAttr; |
| llvm::SMLoc trailingTypeLoc; |
| |
| // Parse an operand list that will, in practice, contain 0 or 1 operand. In |
| // case of an indirect call, there will be 1 operand before `(`. In case of a |
| // direct call, there will be no operands and the parser will stop at the |
| // function identifier without complaining. |
| if (parser.parseOperandList(operands)) |
| return failure(); |
| bool isDirect = operands.empty(); |
| |
| // Optionally parse a function identifier. |
| if (isDirect) |
| if (parser.parseAttribute(funcAttr, "callee", result.attributes)) |
| return failure(); |
| |
| if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) |
| return failure(); |
| |
| auto funcType = type.dyn_cast<FunctionType>(); |
| if (!funcType) |
| return parser.emitError(trailingTypeLoc, "expected function type"); |
| if (isDirect) { |
| // Make sure types match. |
| if (parser.resolveOperands(operands, funcType.getInputs(), |
| parser.getNameLoc(), result.operands)) |
| return failure(); |
| result.addTypes(funcType.getResults()); |
| } else { |
| // Construct the LLVM IR Dialect function type that the first operand |
| // should match. |
| if (funcType.getNumResults() > 1) |
| return parser.emitError(trailingTypeLoc, |
| "expected function with 0 or 1 result"); |
| |
| Builder &builder = parser.getBuilder(); |
| auto *llvmDialect = |
| builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); |
| LLVM::LLVMType llvmResultType; |
| if (funcType.getNumResults() == 0) { |
| llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); |
| } else { |
| llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>(); |
| if (!llvmResultType) |
| return parser.emitError(trailingTypeLoc, |
| "expected result to have LLVM type"); |
| } |
| |
| SmallVector<LLVM::LLVMType, 8> argTypes; |
| argTypes.reserve(funcType.getNumInputs()); |
| for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { |
| auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>(); |
| if (!argType) |
| return parser.emitError(trailingTypeLoc, |
| "expected LLVM types as inputs"); |
| argTypes.push_back(argType); |
| } |
| auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, |
| /*isVarArg=*/false); |
| auto wrappedFuncType = llvmFuncType.getPointerTo(); |
| |
| auto funcArguments = |
| ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); |
| |
| // Make sure that the first operand (indirect callee) matches the wrapped |
| // LLVM IR function type, and that the types of the other call operands |
| // match the types of the function arguments. |
| if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || |
| parser.resolveOperands(funcArguments, funcType.getInputs(), |
| parser.getNameLoc(), result.operands)) |
| return failure(); |
| |
| result.addTypes(llvmResultType); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::ExtractElementOp. |
| //===----------------------------------------------------------------------===// |
| // Expects vector to be of wrapped LLVM vector type and position to be of |
| // wrapped LLVM i32 type. |
| void LLVM::ExtractElementOp::build(Builder *b, OperationState &result, |
| Value vector, Value position, |
| ArrayRef<NamedAttribute> attrs) { |
| auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>(); |
| auto llvmType = wrappedVectorType.getVectorElementType(); |
| build(b, result, llvmType, vector, position); |
| result.addAttributes(attrs); |
| } |
| |
| static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { |
| p << op.getOperationName() << ' ' << op.vector() << "[" << op.position() |
| << " : " << op.position().getType() << "]"; |
| p.printOptionalAttrDict(op.getAttrs()); |
| p << " : " << op.vector().getType(); |
| } |
| |
| // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use |
| // attribute-dict? `:` type |
| static ParseResult parseExtractElementOp(OpAsmParser &parser, |
| OperationState &result) { |
| llvm::SMLoc loc; |
| OpAsmParser::OperandType vector, position; |
| Type type, positionType; |
| if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || |
| parser.parseLSquare() || parser.parseOperand(position) || |
| parser.parseColonType(positionType) || parser.parseRSquare() || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperand(vector, type, result.operands) || |
| parser.resolveOperand(position, positionType, result.operands)) |
| return failure(); |
| auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>(); |
| if (!wrappedVectorType || |
| !wrappedVectorType.getUnderlyingType()->isVectorTy()) |
| return parser.emitError( |
| loc, "expected LLVM IR dialect vector type for operand #1"); |
| result.addTypes(wrappedVectorType.getVectorElementType()); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::ExtractValueOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { |
| p << op.getOperationName() << ' ' << op.container() << op.position(); |
| p.printOptionalAttrDict(op.getAttrs(), {"position"}); |
| p << " : " << op.container().getType(); |
| } |
| |
| // Extract the type at `position` in the wrapped LLVM IR aggregate type |
| // `containerType`. Position is an integer array attribute where each value |
| // is a zero-based position of the element in the aggregate type. Return the |
| // resulting type wrapped in MLIR, or nullptr on error. |
| static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser, |
| Type containerType, |
| ArrayAttr positionAttr, |
| llvm::SMLoc attributeLoc, |
| llvm::SMLoc typeLoc) { |
| auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>(); |
| if (!wrappedContainerType) |
| return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; |
| |
| // Infer the element type from the structure type: iteratively step inside the |
| // type by taking the element type, indexed by the position attribute for |
| // structures. Check the position index before accessing, it is supposed to |
| // be in bounds. |
| for (Attribute subAttr : positionAttr) { |
| auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); |
| if (!positionElementAttr) |
| return parser.emitError(attributeLoc, |
| "expected an array of integer literals"), |
| nullptr; |
| int position = positionElementAttr.getInt(); |
| auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); |
| if (llvmContainerType->isArrayTy()) { |
| if (position < 0 || static_cast<unsigned>(position) >= |
| llvmContainerType->getArrayNumElements()) |
| return parser.emitError(attributeLoc, "position out of bounds"), |
| nullptr; |
| wrappedContainerType = wrappedContainerType.getArrayElementType(); |
| } else if (llvmContainerType->isStructTy()) { |
| if (position < 0 || static_cast<unsigned>(position) >= |
| llvmContainerType->getStructNumElements()) |
| return parser.emitError(attributeLoc, "position out of bounds"), |
| nullptr; |
| wrappedContainerType = |
| wrappedContainerType.getStructElementType(position); |
| } else { |
| return parser.emitError(typeLoc, |
| "expected wrapped LLVM IR structure/array type"), |
| nullptr; |
| } |
| } |
| return wrappedContainerType; |
| } |
| |
| // <operation> ::= `llvm.extractvalue` ssa-use |
| // `[` integer-literal (`,` integer-literal)* `]` |
| // attribute-dict? `:` type |
| static ParseResult parseExtractValueOp(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::OperandType container; |
| Type containerType; |
| ArrayAttr positionAttr; |
| llvm::SMLoc attributeLoc, trailingTypeLoc; |
| |
| if (parser.parseOperand(container) || |
| parser.getCurrentLocation(&attributeLoc) || |
| parser.parseAttribute(positionAttr, "position", result.attributes) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || |
| parser.parseType(containerType) || |
| parser.resolveOperand(container, containerType, result.operands)) |
| return failure(); |
| |
| auto elementType = getInsertExtractValueElementType( |
| parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); |
| if (!elementType) |
| return failure(); |
| |
| result.addTypes(elementType); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::InsertElementOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { |
| p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "[" |
| << op.position() << " : " << op.position().getType() << "]"; |
| p.printOptionalAttrDict(op.getAttrs()); |
| p << " : " << op.vector().getType(); |
| } |
| |
| // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use |
| // attribute-dict? `:` type |
| static ParseResult parseInsertElementOp(OpAsmParser &parser, |
| OperationState &result) { |
| llvm::SMLoc loc; |
| OpAsmParser::OperandType vector, value, position; |
| Type vectorType, positionType; |
| if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || |
| parser.parseComma() || parser.parseOperand(vector) || |
| parser.parseLSquare() || parser.parseOperand(position) || |
| parser.parseColonType(positionType) || parser.parseRSquare() || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(vectorType)) |
| return failure(); |
| |
| auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>(); |
| if (!wrappedVectorType || |
| !wrappedVectorType.getUnderlyingType()->isVectorTy()) |
| return parser.emitError( |
| loc, "expected LLVM IR dialect vector type for operand #1"); |
| auto valueType = wrappedVectorType.getVectorElementType(); |
| if (!valueType) |
| return failure(); |
| |
| if (parser.resolveOperand(vector, vectorType, result.operands) || |
| parser.resolveOperand(value, valueType, result.operands) || |
| parser.resolveOperand(position, positionType, result.operands)) |
| return failure(); |
| |
| result.addTypes(vectorType); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::InsertValueOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { |
| p << op.getOperationName() << ' ' << op.value() << ", " << op.container() |
| << op.position(); |
| p.printOptionalAttrDict(op.getAttrs(), {"position"}); |
| p << " : " << op.container().getType(); |
| } |
| |
| // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use |
| // `[` integer-literal (`,` integer-literal)* `]` |
| // attribute-dict? `:` type |
| static ParseResult parseInsertValueOp(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::OperandType container, value; |
| Type containerType; |
| ArrayAttr positionAttr; |
| llvm::SMLoc attributeLoc, trailingTypeLoc; |
| |
| if (parser.parseOperand(value) || parser.parseComma() || |
| parser.parseOperand(container) || |
| parser.getCurrentLocation(&attributeLoc) || |
| parser.parseAttribute(positionAttr, "position", result.attributes) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || |
| parser.parseType(containerType)) |
| return failure(); |
| |
| auto valueType = getInsertExtractValueElementType( |
| parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); |
| if (!valueType) |
| return failure(); |
| |
| if (parser.resolveOperand(container, containerType, result.operands) || |
| parser.resolveOperand(value, valueType, result.operands)) |
| return failure(); |
| |
| result.addTypes(containerType); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::BrOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printBrOp(OpAsmPrinter &p, BrOp &op) { |
| p << op.getOperationName() << ' '; |
| p.printSuccessorAndUseList(op.getOperation(), 0); |
| p.printOptionalAttrDict(op.getAttrs()); |
| } |
| |
| // <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)? |
| // attribute-dict? |
| static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) { |
| Block *dest; |
| SmallVector<Value, 4> operands; |
| if (parser.parseSuccessorAndUseList(dest, operands) || |
| parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| |
| result.addSuccessor(dest, operands); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::CondBrOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) { |
| p << op.getOperationName() << ' ' << op.getOperand(0) << ", "; |
| p.printSuccessorAndUseList(op.getOperation(), 0); |
| p << ", "; |
| p.printSuccessorAndUseList(op.getOperation(), 1); |
| p.printOptionalAttrDict(op.getAttrs()); |
| } |
| |
| // <operation> ::= `llvm.cond_br` ssa-use `,` |
| // bb-id (`[` ssa-use-and-type-list `]`)? `,` |
| // bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict? |
| static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) { |
| Block *trueDest; |
| Block *falseDest; |
| SmallVector<Value, 4> trueOperands; |
| SmallVector<Value, 4> falseOperands; |
| OpAsmParser::OperandType condition; |
| |
| Builder &builder = parser.getBuilder(); |
| auto *llvmDialect = |
| builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); |
| auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect); |
| |
| if (parser.parseOperand(condition) || parser.parseComma() || |
| parser.parseSuccessorAndUseList(trueDest, trueOperands) || |
| parser.parseComma() || |
| parser.parseSuccessorAndUseList(falseDest, falseOperands) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.resolveOperand(condition, i1Type, result.operands)) |
| return failure(); |
| |
| result.addSuccessor(trueDest, trueOperands); |
| result.addSuccessor(falseDest, falseOperands); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::ReturnOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) { |
| p << op.getOperationName(); |
| p.printOptionalAttrDict(op.getAttrs()); |
| assert(op.getNumOperands() <= 1); |
| |
| if (op.getNumOperands() == 0) |
| return; |
| |
| p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); |
| } |
| |
| // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:` |
| // type-list-no-parens |
| static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 1> operands; |
| Type type; |
| |
| if (parser.parseOperandList(operands) || |
| parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| if (operands.empty()) |
| return success(); |
| |
| if (parser.parseColonType(type) || |
| parser.resolveOperand(operands[0], type, result.operands)) |
| return failure(); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Verifier for LLVM::AddressOfOp. |
| //===----------------------------------------------------------------------===// |
| |
| GlobalOp AddressOfOp::getGlobal() { |
| Operation *module = getParentOp(); |
| while (module && !satisfiesLLVMModule(module)) |
| module = module->getParentOp(); |
| assert(module && "unexpected operation outside of a module"); |
| return dyn_cast_or_null<LLVM::GlobalOp>( |
| mlir::SymbolTable::lookupSymbolIn(module, global_name())); |
| } |
| |
| static LogicalResult verify(AddressOfOp op) { |
| auto global = op.getGlobal(); |
| if (!global) |
| return op.emitOpError( |
| "must reference a global defined by 'llvm.mlir.global'"); |
| |
| if (global.getType().getPointerTo(global.addr_space().getZExtValue()) != |
| op.getResult().getType()) |
| return op.emitOpError( |
| "the type must be a pointer to the type of the referred global"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Builder, printer and verifier for LLVM::GlobalOp. |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns the name used for the linkage attribute. This *must* correspond to |
| /// the name of the attribute in ODS. |
| static StringRef getLinkageAttrName() { return "linkage"; } |
| |
| void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type, |
| bool isConstant, Linkage linkage, StringRef name, |
| Attribute value, unsigned addrSpace, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder->getStringAttr(name)); |
| result.addAttribute("type", TypeAttr::get(type)); |
| if (isConstant) |
| result.addAttribute("constant", builder->getUnitAttr()); |
| if (value) |
| result.addAttribute("value", value); |
| result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr( |
| static_cast<int64_t>(linkage))); |
| if (addrSpace != 0) |
| result.addAttribute("addr_space", builder->getI32IntegerAttr(addrSpace)); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| result.addRegion(); |
| } |
| |
| static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { |
| p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' '; |
| if (op.constant()) |
| p << "constant "; |
| p.printSymbolName(op.sym_name()); |
| p << '('; |
| if (auto value = op.getValueOrNull()) |
| p.printAttribute(value); |
| p << ')'; |
| p.printOptionalAttrDict(op.getAttrs(), |
| {SymbolTable::getSymbolAttrName(), "type", "constant", |
| "value", getLinkageAttrName()}); |
| |
| // Print the trailing type unless it's a string global. |
| if (op.getValueOrNull().dyn_cast_or_null<StringAttr>()) |
| return; |
| p << " : " << op.type(); |
| |
| Region &initializer = op.getInitializerRegion(); |
| if (!initializer.empty()) |
| p.printRegion(initializer, /*printEntryBlockArgs=*/false); |
| } |
| |
| // Parses one of the keywords provided in the list `keywords` and returns the |
| // position of the parsed keyword in the list. If none of the keywords from the |
| // list is parsed, returns -1. |
| static int parseOptionalKeywordAlternative(OpAsmParser &parser, |
| ArrayRef<StringRef> keywords) { |
| for (auto en : llvm::enumerate(keywords)) { |
| if (succeeded(parser.parseOptionalKeyword(en.value()))) |
| return en.index(); |
| } |
| return -1; |
| } |
| |
| namespace { |
| template <typename Ty> struct EnumTraits {}; |
| |
| #define REGISTER_ENUM_TYPE(Ty) \ |
| template <> struct EnumTraits<Ty> { \ |
| static StringRef stringify(Ty value) { return stringify##Ty(value); } \ |
| static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ |
| } |
| |
| REGISTER_ENUM_TYPE(Linkage); |
| } // end namespace |
| |
| template <typename EnumTy> |
| static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser, |
| OperationState &result, |
| StringRef name) { |
| SmallVector<StringRef, 10> names; |
| for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i) |
| names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); |
| |
| int index = parseOptionalKeywordAlternative(parser, names); |
| if (index == -1) |
| return failure(); |
| result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index)); |
| return success(); |
| } |
| |
| // operation ::= `llvm.mlir.global` linkage `constant`? `@` identifier |
| // `(` attribute? `)` attribute-list? (`:` type)? region? |
| // |
| // The type can be omitted for string attributes, in which case it will be |
| // inferred from the value of the string as [strlen(value) x i8]. |
| static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { |
| if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, |
| getLinkageAttrName()))) |
| return parser.emitError(parser.getCurrentLocation(), "expected linkage"); |
| |
| if (succeeded(parser.parseOptionalKeyword("constant"))) |
| result.addAttribute("constant", parser.getBuilder().getUnitAttr()); |
| |
| StringAttr name; |
| if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), |
| result.attributes) || |
| parser.parseLParen()) |
| return failure(); |
| |
| Attribute value; |
| if (parser.parseOptionalRParen()) { |
| if (parser.parseAttribute(value, "value", result.attributes) || |
| parser.parseRParen()) |
| return failure(); |
| } |
| |
| SmallVector<Type, 1> types; |
| if (parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseOptionalColonTypeList(types)) |
| return failure(); |
| |
| if (types.size() > 1) |
| return parser.emitError(parser.getNameLoc(), "expected zero or one type"); |
| |
| Region &initRegion = *result.addRegion(); |
| if (types.empty()) { |
| if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { |
| MLIRContext *context = parser.getBuilder().getContext(); |
| auto *dialect = context->getRegisteredDialect<LLVMDialect>(); |
| auto arrayType = LLVM::LLVMType::getArrayTy( |
| LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size()); |
| types.push_back(arrayType); |
| } else { |
| return parser.emitError(parser.getNameLoc(), |
| "type can only be omitted for string globals"); |
| } |
| } else if (parser.parseOptionalRegion(initRegion, /*arguments=*/{}, |
| /*argTypes=*/{})) { |
| return failure(); |
| } |
| |
| result.addAttribute("type", TypeAttr::get(types[0])); |
| return success(); |
| } |
| |
| static LogicalResult verify(GlobalOp op) { |
| if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) |
| return op.emitOpError( |
| "expects type to be a valid element type for an LLVM pointer"); |
| if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp())) |
| return op.emitOpError("must appear at the module level"); |
| |
| if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { |
| auto type = op.getType(); |
| if (!type.getUnderlyingType()->isArrayTy() || |
| !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) || |
| type.getArrayNumElements() != strAttr.getValue().size()) |
| return op.emitOpError( |
| "requires an i8 array type of the length equal to that of the string " |
| "attribute"); |
| } |
| |
| if (Block *b = op.getInitializerBlock()) { |
| ReturnOp ret = cast<ReturnOp>(b->getTerminator()); |
| if (ret.operand_type_begin() == ret.operand_type_end()) |
| return op.emitOpError("initializer region cannot return void"); |
| if (*ret.operand_type_begin() != op.getType()) |
| return op.emitOpError("initializer region type ") |
| << *ret.operand_type_begin() << " does not match global type " |
| << op.getType(); |
| |
| if (op.getValueOrNull()) |
| return op.emitOpError("cannot have both initializer value and region"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing/parsing for LLVM::ShuffleVectorOp. |
| //===----------------------------------------------------------------------===// |
| // Expects vector to be of wrapped LLVM vector type and position to be of |
| // wrapped LLVM i32 type. |
| void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1, |
| Value v2, ArrayAttr mask, |
| ArrayRef<NamedAttribute> attrs) { |
| auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>(); |
| auto vType = LLVMType::getVectorTy( |
| wrappedContainerType1.getVectorElementType(), mask.size()); |
| build(b, result, vType, v1, v2, mask); |
| result.addAttributes(attrs); |
| } |
| |
| static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { |
| p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " " |
| << op.mask(); |
| p.printOptionalAttrDict(op.getAttrs(), {"mask"}); |
| p << " : " << op.v1().getType() << ", " << op.v2().getType(); |
| } |
| |
| // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use |
| // `[` integer-literal (`,` integer-literal)* `]` |
| // attribute-dict? `:` type |
| static ParseResult parseShuffleVectorOp(OpAsmParser &parser, |
| OperationState &result) { |
| llvm::SMLoc loc; |
| OpAsmParser::OperandType v1, v2; |
| ArrayAttr maskAttr; |
| Type typeV1, typeV2; |
| if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || |
| parser.parseComma() || parser.parseOperand(v2) || |
| parser.parseAttribute(maskAttr, "mask", result.attributes) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(typeV1) || parser.parseComma() || |
| parser.parseType(typeV2) || |
| parser.resolveOperand(v1, typeV1, result.operands) || |
| parser.resolveOperand(v2, typeV2, result.operands)) |
| return failure(); |
| auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>(); |
| if (!wrappedContainerType1 || |
| !wrappedContainerType1.getUnderlyingType()->isVectorTy()) |
| return parser.emitError( |
| loc, "expected LLVM IR dialect vector type for operand #1"); |
| auto vType = LLVMType::getVectorTy( |
| wrappedContainerType1.getVectorElementType(), maskAttr.size()); |
| result.addTypes(vType); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Implementations for LLVM::LLVMFuncOp. |
| //===----------------------------------------------------------------------===// |
| |
| // Add the entry block to the function. |
| Block *LLVMFuncOp::addEntryBlock() { |
| assert(empty() && "function already has an entry block"); |
| assert(!isVarArg() && "unimplemented: non-external variadic functions"); |
| |
| auto *entry = new Block; |
| push_back(entry); |
| |
| LLVMType type = getType(); |
| for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i) |
| entry->addArgument(type.getFunctionParamType(i)); |
| return entry; |
| } |
| |
| void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name, |
| LLVMType type, LLVM::Linkage linkage, |
| ArrayRef<NamedAttribute> attrs, |
| ArrayRef<NamedAttributeList> argAttrs) { |
| result.addRegion(); |
| result.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder->getStringAttr(name)); |
| result.addAttribute("type", TypeAttr::get(type)); |
| result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr( |
| static_cast<int64_t>(linkage))); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| if (argAttrs.empty()) |
| return; |
| |
| unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams(); |
| assert(numInputs == argAttrs.size() && |
| "expected as many argument attribute lists as arguments"); |
| SmallString<8> argAttrName; |
| for (unsigned i = 0; i < numInputs; ++i) |
| if (auto argDict = argAttrs[i].getDictionary()) |
| result.addAttribute(getArgAttrName(i, argAttrName), argDict); |
| } |
| |
| // Builds an LLVM function type from the given lists of input and output types. |
| // Returns a null type if any of the types provided are non-LLVM types, or if |
| // there is more than one output type. |
| static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, |
| ArrayRef<Type> inputs, ArrayRef<Type> outputs, |
| impl::VariadicFlag variadicFlag) { |
| Builder &b = parser.getBuilder(); |
| if (outputs.size() > 1) { |
| parser.emitError(loc, "failed to construct function type: expected zero or " |
| "one function result"); |
| return {}; |
| } |
| |
| // Convert inputs to LLVM types, exit early on error. |
| SmallVector<LLVMType, 4> llvmInputs; |
| for (auto t : inputs) { |
| auto llvmTy = t.dyn_cast<LLVMType>(); |
| if (!llvmTy) { |
| parser.emitError(loc, "failed to construct function type: expected LLVM " |
| "type for function arguments"); |
| return {}; |
| } |
| llvmInputs.push_back(llvmTy); |
| } |
| |
| // Get the dialect from the input type, if any exist. Look it up in the |
| // context otherwise. |
| LLVMDialect *dialect = |
| llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>() |
| : &llvmInputs.front().getDialect(); |
| |
| // No output is denoted as "void" in LLVM type system. |
| LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) |
| : outputs.front().dyn_cast<LLVMType>(); |
| if (!llvmOutput) { |
| parser.emitError(loc, "failed to construct function type: expected LLVM " |
| "type for function results"); |
| return {}; |
| } |
| return LLVMType::getFunctionTy(llvmOutput, llvmInputs, |
| variadicFlag.isVariadic()); |
| } |
| |
| // Parses an LLVM function. |
| // |
| // operation ::= `llvm.func` linkage? function-signature function-attributes? |
| // function-body |
| // |
| static ParseResult parseLLVMFuncOp(OpAsmParser &parser, |
| OperationState &result) { |
| // Default to external linkage if no keyword is provided. |
| if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, |
| getLinkageAttrName()))) |
| result.addAttribute(getLinkageAttrName(), |
| parser.getBuilder().getI64IntegerAttr( |
| static_cast<int64_t>(LLVM::Linkage::External))); |
| |
| StringAttr nameAttr; |
| SmallVector<OpAsmParser::OperandType, 8> entryArgs; |
| SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs; |
| SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs; |
| SmallVector<Type, 8> argTypes; |
| SmallVector<Type, 4> resultTypes; |
| bool isVariadic; |
| |
| auto signatureLocation = parser.getCurrentLocation(); |
| if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), |
| result.attributes) || |
| impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs, |
| argTypes, argAttrs, isVariadic, resultTypes, |
| resultAttrs)) |
| return failure(); |
| |
| auto type = |
| buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, |
| impl::VariadicFlag(isVariadic)); |
| if (!type) |
| return failure(); |
| result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type)); |
| |
| if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) |
| return failure(); |
| impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs, |
| resultAttrs); |
| |
| auto *body = result.addRegion(); |
| return parser.parseOptionalRegion( |
| *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); |
| } |
| |
| // Print the LLVMFuncOp. Collects argument and result types and passes them to |
| // helper functions. Drops "void" result since it cannot be parsed back. Skips |
| // the external linkage since it is the default value. |
| static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { |
| p << op.getOperationName() << ' '; |
| if (op.linkage() != LLVM::Linkage::External) |
| p << stringifyLinkage(op.linkage()) << ' '; |
| p.printSymbolName(op.getName()); |
| |
| LLVMType fnType = op.getType(); |
| SmallVector<Type, 8> argTypes; |
| SmallVector<Type, 1> resTypes; |
| argTypes.reserve(fnType.getFunctionNumParams()); |
| for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) |
| argTypes.push_back(fnType.getFunctionParamType(i)); |
| |
| LLVMType returnType = fnType.getFunctionResultType(); |
| if (!returnType.isVoidTy()) |
| resTypes.push_back(returnType); |
| |
| impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes); |
| impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(), |
| {getLinkageAttrName()}); |
| |
| // Print the body if this is not an external function. |
| Region &body = op.body(); |
| if (!body.empty()) |
| p.printRegion(body, /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| } |
| |
| // Hook for OpTrait::FunctionLike, called after verifying that the 'type' |
| // attribute is present. This can check for preconditions of the |
| // getNumArguments hook not failing. |
| LogicalResult LLVMFuncOp::verifyType() { |
| auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>(); |
| if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy()) |
| return emitOpError("requires '" + getTypeAttrName() + |
| "' attribute of wrapped LLVM function type"); |
| |
| return success(); |
| } |
| |
| // Hook for OpTrait::FunctionLike, returns the number of function arguments. |
| // Depends on the type attribute being correct as checked by verifyType |
| unsigned LLVMFuncOp::getNumFuncArguments() { |
| return getType().getUnderlyingType()->getFunctionNumParams(); |
| } |
| |
| // Hook for OpTrait::FunctionLike, returns the number of function results. |
| // Depends on the type attribute being correct as checked by verifyType |
| unsigned LLVMFuncOp::getNumFuncResults() { |
| // We model LLVM functions that return void as having zero results, |
| // and all others as having one result. |
| // If we modeled a void return as one result, then it would be possible to |
| // attach an MLIR result attribute to it, and it isn't clear what semantics we |
| // would assign to that. |
| if (getType().getFunctionResultType().isVoidTy()) |
| return 0; |
| return 1; |
| } |
| |
| // Verifies LLVM- and implementation-specific properties of the LLVM func Op: |
| // - functions don't have 'common' linkage |
| // - external functions have 'external' or 'extern_weak' linkage; |
| // - vararg is (currently) only supported for external functions; |
| // - entry block arguments are of LLVM types and match the function signature. |
| static LogicalResult verify(LLVMFuncOp op) { |
| if (op.linkage() == LLVM::Linkage::Common) |
| return op.emitOpError() |
| << "functions cannot have '" |
| << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; |
| |
| if (op.isExternal()) { |
| if (op.linkage() != LLVM::Linkage::External && |
| op.linkage() != LLVM::Linkage::ExternWeak) |
| return op.emitOpError() |
| << "external functions must have '" |
| << stringifyLinkage(LLVM::Linkage::External) << "' or '" |
| << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; |
| return success(); |
| } |
| |
| if (op.isVarArg()) |
| return op.emitOpError("only external functions can be variadic"); |
| |
| auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType()); |
| unsigned numArguments = funcType->getNumParams(); |
| Block &entryBlock = op.front(); |
| for (unsigned i = 0; i < numArguments; ++i) { |
| Type argType = entryBlock.getArgument(i).getType(); |
| auto argLLVMType = argType.dyn_cast<LLVMType>(); |
| if (!argLLVMType) |
| return op.emitOpError("entry block argument #") |
| << i << " is not of LLVM type"; |
| if (funcType->getParamType(i) != argLLVMType.getUnderlyingType()) |
| return op.emitOpError("the type of entry block argument #") |
| << i << " does not match the function signature"; |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Verification for LLVM::NullOp. |
| //===----------------------------------------------------------------------===// |
| |
| // Only LLVM pointer types are supported. |
| static LogicalResult verify(LLVM::NullOp op) { |
| auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>(); |
| if (!llvmType || !llvmType.isPointerTy()) |
| return op.emitOpError("expected LLVM IR pointer type"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Utility functions for parsing atomic ops |
| //===----------------------------------------------------------------------===// |
| |
| // Helper function to parse a keyword into the specified attribute named by |
| // `attrName`. The keyword must match one of the string values defined by the |
| // AtomicBinOp enum. The resulting I64 attribute is added to the `result` |
| // state. |
| static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, |
| StringRef attrName) { |
| llvm::SMLoc loc; |
| StringRef keyword; |
| if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) |
| return failure(); |
| |
| // Replace the keyword `keyword` with an integer attribute. |
| auto kind = symbolizeAtomicBinOp(keyword); |
| if (!kind) { |
| return parser.emitError(loc) |
| << "'" << keyword << "' is an incorrect value of the '" << attrName |
| << "' attribute"; |
| } |
| |
| auto value = static_cast<int64_t>(kind.getValue()); |
| auto attr = parser.getBuilder().getI64IntegerAttr(value); |
| result.addAttribute(attrName, attr); |
| |
| return success(); |
| } |
| |
| // Helper function to parse a keyword into the specified attribute named by |
| // `attrName`. The keyword must match one of the string values defined by the |
| // AtomicOrdering enum. The resulting I64 attribute is added to the `result` |
| // state. |
| static ParseResult parseAtomicOrdering(OpAsmParser &parser, |
| OperationState &result, |
| StringRef attrName) { |
| llvm::SMLoc loc; |
| StringRef ordering; |
| if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) |
| return failure(); |
| |
| // Replace the keyword `ordering` with an integer attribute. |
| auto kind = symbolizeAtomicOrdering(ordering); |
| if (!kind) { |
| return parser.emitError(loc) |
| << "'" << ordering << "' is an incorrect value of the '" << attrName |
| << "' attribute"; |
| } |
| |
| auto value = static_cast<int64_t>(kind.getValue()); |
| auto attr = parser.getBuilder().getI64IntegerAttr(value); |
| result.addAttribute(attrName, attr); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printer, parser and verifier for LLVM::AtomicRMWOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { |
| p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' ' |
| << op.ptr() << ", " << op.val() << ' ' |
| << stringifyAtomicOrdering(op.ordering()) << ' '; |
| p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"}); |
| p << " : " << op.res().getType(); |
| } |
| |
| // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword |
| // attribute-dict? `:` type |
| static ParseResult parseAtomicRMWOp(OpAsmParser &parser, |
| OperationState &result) { |
| LLVMType type; |
| OpAsmParser::OperandType ptr, val; |
| if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || |
| parser.parseComma() || parser.parseOperand(val) || |
| parseAtomicOrdering(parser, result, "ordering") || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperand(ptr, type.getPointerTo(), result.operands) || |
| parser.resolveOperand(val, type, result.operands)) |
| return failure(); |
| |
| result.addTypes(type); |
| return success(); |
| } |
| |
| static LogicalResult verify(AtomicRMWOp op) { |
| auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>(); |
| if (!ptrType.isPointerTy()) |
| return op.emitOpError("expected LLVM IR pointer type for operand #0"); |
| auto valType = op.val().getType().cast<LLVM::LLVMType>(); |
| if (valType != ptrType.getPointerElementTy()) |
| return op.emitOpError("expected LLVM IR element type for operand #0 to " |
| "match type for operand #1"); |
| auto resType = op.res().getType().cast<LLVM::LLVMType>(); |
| if (resType != valType) |
| return op.emitOpError( |
| "expected LLVM IR result type to match type for operand #1"); |
| if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { |
| if (!valType.getUnderlyingType()->isFloatingPointTy()) |
| return op.emitOpError("expected LLVM IR floating point type"); |
| } else if (op.bin_op() == AtomicBinOp::xchg) { |
| if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && |
| !valType.isIntegerTy(32) && !valType.isIntegerTy(64) && |
| !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy()) |
| return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); |
| } else { |
| if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && |
| !valType.isIntegerTy(32) && !valType.isIntegerTy(64)) |
| return op.emitOpError("expected LLVM IR integer type"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. |
| //===----------------------------------------------------------------------===// |
| |
| static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { |
| p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", " |
| << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' ' |
| << stringifyAtomicOrdering(op.failure_ordering()); |
| p.printOptionalAttrDict(op.getAttrs(), |
| {"success_ordering", "failure_ordering"}); |
| p << " : " << op.val().getType(); |
| } |
| |
| // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use |
| // keyword keyword attribute-dict? `:` type |
| static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| LLVMType type; |
| OpAsmParser::OperandType ptr, cmp, val; |
| if (parser.parseOperand(ptr) || parser.parseComma() || |
| parser.parseOperand(cmp) || parser.parseComma() || |
| parser.parseOperand(val) || |
| parseAtomicOrdering(parser, result, "success_ordering") || |
| parseAtomicOrdering(parser, result, "failure_ordering") || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperand(ptr, type.getPointerTo(), result.operands) || |
| parser.resolveOperand(cmp, type, result.operands) || |
| parser.resolveOperand(val, type, result.operands)) |
| return failure(); |
| |
| auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>(); |
| auto boolType = LLVMType::getInt1Ty(dialect); |
| auto resultType = LLVMType::getStructTy(type, boolType); |
| result.addTypes(resultType); |
| |
| return success(); |
| } |
| |
| static LogicalResult verify(AtomicCmpXchgOp op) { |
| auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>(); |
| if (!ptrType.isPointerTy()) |
| return op.emitOpError("expected LLVM IR pointer type for operand #0"); |
| auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>(); |
| auto valType = op.val().getType().cast<LLVM::LLVMType>(); |
| if (cmpType != ptrType.getPointerElementTy() || cmpType != valType) |
| return op.emitOpError("expected LLVM IR element type for operand #0 to " |
| "match type for all other operands"); |
| if (!valType.isPointerTy() && !valType.isIntegerTy(8) && |
| !valType.isIntegerTy(16) && !valType.isIntegerTy(32) && |
| !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() && |
| !valType.isDoubleTy()) |
| return op.emitOpError("unexpected LLVM IR type"); |
| if (op.success_ordering() < AtomicOrdering::monotonic || |
| op.failure_ordering() < AtomicOrdering::monotonic) |
| return op.emitOpError("ordering must be at least 'monotonic'"); |
| if (op.failure_ordering() == AtomicOrdering::release || |
| op.failure_ordering() == AtomicOrdering::acq_rel) |
| return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVMDialect initialization, type parsing, and registration. |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| namespace LLVM { |
| namespace detail { |
| struct LLVMDialectImpl { |
| LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} |
| |
| llvm::LLVMContext llvmContext; |
| llvm::Module module; |
| |
| /// A set of LLVMTypes that are cached on construction to avoid any lookups or |
| /// locking. |
| LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; |
| LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty; |
| LLVMType voidTy; |
| |
| /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not |
| /// multi-threaded and requires locked access to prevent race conditions. |
| llvm::sys::SmartMutex<true> mutex; |
| }; |
| } // end namespace detail |
| } // end namespace LLVM |
| } // end namespace mlir |
| |
| LLVMDialect::LLVMDialect(MLIRContext *context) |
| : Dialect(getDialectNamespace(), context), |
| impl(new detail::LLVMDialectImpl()) { |
| addTypes<LLVMType>(); |
| addOperations< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" |
| >(); |
| |
| // Support unknown operations because not all LLVM operations are registered. |
| allowUnknownOperations(); |
| |
| // Cache some of the common LLVM types to avoid the need for lookups/locking. |
| auto &llvmContext = impl->llvmContext; |
| /// Integer Types. |
| impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext)); |
| impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext)); |
| impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext)); |
| impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext)); |
| impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext)); |
| impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext)); |
| /// Float Types. |
| impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext)); |
| impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext)); |
| impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext)); |
| impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext)); |
| impl->x86_fp80Ty = |
| LLVMType::get(context, llvm::Type::getX86_FP80Ty(llvmContext)); |
| /// Other Types. |
| impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext)); |
| } |
| |
| LLVMDialect::~LLVMDialect() {} |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" |
| |
| llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } |
| llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } |
| |
| /// Parse a type registered to this dialect. |
| Type LLVMDialect::parseType(DialectAsmParser &parser) const { |
| StringRef tyData = parser.getFullSymbolSpec(); |
| |
| // LLVM is not thread-safe, so lock access to it. |
| llvm::sys::SmartScopedLock<true> lock(impl->mutex); |
| |
| llvm::SMDiagnostic errorMessage; |
| llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module); |
| if (!type) |
| return (parser.emitError(parser.getNameLoc(), errorMessage.getMessage()), |
| nullptr); |
| return LLVMType::get(getContext(), type); |
| } |
| |
| /// Print a type registered to this dialect. |
| void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { |
| auto llvmType = type.dyn_cast<LLVMType>(); |
| assert(llvmType && "printing wrong type"); |
| assert(llvmType.getUnderlyingType() && "no underlying LLVM type"); |
| llvmType.getUnderlyingType()->print(os.getStream()); |
| } |
| |
| /// Verify LLVMIR function argument attributes. |
| LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, |
| unsigned regionIdx, |
| unsigned argIdx, |
| NamedAttribute argAttr) { |
| // Check that llvm.noalias is a boolean attribute. |
| if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>()) |
| return op->emitError() |
| << "llvm.noalias argument attribute of non boolean type"; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVMType. |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| namespace LLVM { |
| namespace detail { |
| struct LLVMTypeStorage : public ::mlir::TypeStorage { |
| LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {} |
| |
| // LLVM types are pointer-unique. |
| using KeyTy = llvm::Type *; |
| bool operator==(const KeyTy &key) const { return key == underlyingType; } |
| |
| static LLVMTypeStorage *construct(TypeStorageAllocator &allocator, |
| llvm::Type *ty) { |
| return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty); |
| } |
| |
| llvm::Type *underlyingType; |
| }; |
| } // end namespace detail |
| } // end namespace LLVM |
| } // end namespace mlir |
| |
| LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { |
| return Base::get(context, FIRST_LLVM_TYPE, llvmType); |
| } |
| |
| /// Get an LLVMType with an llvm type that may cause changes to the underlying |
| /// llvm context when constructed. |
| LLVMType LLVMType::getLocked(LLVMDialect *dialect, |
| function_ref<llvm::Type *()> typeBuilder) { |
| // Lock access to the llvm context and build the type. |
| llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex); |
| return get(dialect->getContext(), typeBuilder()); |
| } |
| |
| LLVMDialect &LLVMType::getDialect() { |
| return static_cast<LLVMDialect &>(Type::getDialect()); |
| } |
| |
| llvm::Type *LLVMType::getUnderlyingType() const { |
| return getImpl()->underlyingType; |
| } |
| |
| /// Array type utilities. |
| LLVMType LLVMType::getArrayElementType() { |
| return get(getContext(), getUnderlyingType()->getArrayElementType()); |
| } |
| unsigned LLVMType::getArrayNumElements() { |
| return getUnderlyingType()->getArrayNumElements(); |
| } |
| bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); } |
| |
| /// Vector type utilities. |
| LLVMType LLVMType::getVectorElementType() { |
| return get(getContext(), getUnderlyingType()->getVectorElementType()); |
| } |
| bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); } |
| |
| /// Function type utilities. |
| LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { |
| return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); |
| } |
| unsigned LLVMType::getFunctionNumParams() { |
| return getUnderlyingType()->getFunctionNumParams(); |
| } |
| LLVMType LLVMType::getFunctionResultType() { |
| return get( |
| getContext(), |
| llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType()); |
| } |
| bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); } |
| |
| /// Pointer type utilities. |
| LLVMType LLVMType::getPointerTo(unsigned addrSpace) { |
| // Lock access to the dialect as this may modify the LLVM context. |
| return getLocked(&getDialect(), [=] { |
| return getUnderlyingType()->getPointerTo(addrSpace); |
| }); |
| } |
| LLVMType LLVMType::getPointerElementTy() { |
| return get(getContext(), getUnderlyingType()->getPointerElementType()); |
| } |
| bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } |
| |
| /// Struct type utilities. |
| LLVMType LLVMType::getStructElementType(unsigned i) { |
| return get(getContext(), getUnderlyingType()->getStructElementType(i)); |
| } |
| unsigned LLVMType::getStructNumElements() { |
| return getUnderlyingType()->getStructNumElements(); |
| } |
| bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); } |
| |
| /// Utilities used to generate floating point types. |
| LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { |
| return dialect->impl->doubleTy; |
| } |
| LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { |
| return dialect->impl->floatTy; |
| } |
| LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { |
| return dialect->impl->halfTy; |
| } |
| LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) { |
| return dialect->impl->fp128Ty; |
| } |
| LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) { |
| return dialect->impl->x86_fp80Ty; |
| } |
| |
| /// Utilities used to generate integer types. |
| LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) { |
| switch (numBits) { |
| case 1: |
| return dialect->impl->int1Ty; |
| case 8: |
| return dialect->impl->int8Ty; |
| case 16: |
| return dialect->impl->int16Ty; |
| case 32: |
| return dialect->impl->int32Ty; |
| case 64: |
| return dialect->impl->int64Ty; |
| case 128: |
| return dialect->impl->int128Ty; |
| default: |
| break; |
| } |
| |
| // Lock access to the dialect as this may modify the LLVM context. |
| return getLocked(dialect, [=] { |
| return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits); |
| }); |
| } |
| |
| /// Utilities used to generate other miscellaneous types. |
| LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { |
| // Lock access to the dialect as this may modify the LLVM context. |
| return getLocked(&elementType.getDialect(), [=] { |
| return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements); |
| }); |
| } |
| LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params, |
| bool isVarArg) { |
| SmallVector<llvm::Type *, 8> llvmParams; |
| for (auto param : params) |
| llvmParams.push_back(param.getUnderlyingType()); |
| |
| // Lock access to the dialect as this may modify the LLVM context. |
| return getLocked(&result.getDialect(), [=] { |
| return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams, |
| isVarArg); |
| }); |
| } |
| LLVMType LLVMType::getStructTy(LLVMDialect *dialect, |
| ArrayRef<LLVMType> elements, bool isPacked) { |
| SmallVector<llvm::Type *, 8> llvmElements; |
| for (auto elt : elements) |
| llvmElements.push_back(elt.getUnderlyingType()); |
| |
| // Lock access to the dialect as this may modify the LLVM context. |
| return getLocked(dialect, [=] { |
| return llvm::StructType::get(dialect->getLLVMContext(), llvmElements, |
| isPacked); |
| }); |
| } |
| inline static SmallVector<llvm::Type *, 8> |
| toUnderlyingTypes(ArrayRef<LLVMType> elements) { |
| SmallVector<llvm::Type *, 8> llvmElements; |
| for (auto elt : elements) |
| llvmElements.push_back(elt.getUnderlyingType()); |
| return llvmElements; |
| } |
| LLVMType LLVMType::createStructTy(LLVMDialect *dialect, |
| ArrayRef<LLVMType> elements, |
| Optional<StringRef> name, bool isPacked) { |
| StringRef sr = name.hasValue() ? *name : ""; |
| SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements)); |
| return getLocked(dialect, [=] { |
| auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr); |
| if (!llvmElements.empty()) |
| rv->setBody(llvmElements, isPacked); |
| return rv; |
| }); |
| } |
| LLVMType LLVMType::setStructTyBody(LLVMType structType, |
| ArrayRef<LLVMType> elements, bool isPacked) { |
| llvm::StructType *st = |
| llvm::cast<llvm::StructType>(structType.getUnderlyingType()); |
| SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements)); |
| return getLocked(&structType.getDialect(), [=] { |
| st->setBody(llvmElements, isPacked); |
| return st; |
| }); |
| } |
| LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { |
| // Lock access to the dialect as this may modify the LLVM context. |
| return getLocked(&elementType.getDialect(), [=] { |
| return llvm::VectorType::get(elementType.getUnderlyingType(), numElements); |
| }); |
| } |
| |
| LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { |
| return dialect->impl->voidTy; |
| } |
| |
| bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); } |
| |
| //===----------------------------------------------------------------------===// |
| // Utility functions. |
| //===----------------------------------------------------------------------===// |
| |
| Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, |
| StringRef name, StringRef value, |
| LLVM::Linkage linkage, |
| LLVM::LLVMDialect *llvmDialect) { |
| assert(builder.getInsertionBlock() && |
| builder.getInsertionBlock()->getParentOp() && |
| "expected builder to point to a block constrained in an op"); |
| auto module = |
| builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); |
| assert(module && "builder points to an op outside of a module"); |
| |
| // Create the global at the entry of the module. |
| OpBuilder moduleBuilder(module.getBodyRegion()); |
| auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect), |
| value.size()); |
| auto global = moduleBuilder.create<LLVM::GlobalOp>( |
| loc, type, /*isConstant=*/true, linkage, name, |
| builder.getStringAttr(value)); |
| |
| // Get the pointer to the first character in the global string. |
| Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); |
| Value cst0 = builder.create<LLVM::ConstantOp>( |
| loc, LLVM::LLVMType::getInt64Ty(llvmDialect), |
| builder.getIntegerAttr(builder.getIndexType(), 0)); |
| return builder.create<LLVM::GEPOp>(loc, |
| LLVM::LLVMType::getInt8PtrTy(llvmDialect), |
| globalPtr, ArrayRef<Value>({cst0, cst0})); |
| } |
| |
| bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { |
| return op->hasTrait<OpTrait::SymbolTable>() && |
| op->hasTrait<OpTrait::IsIsolatedFromAbove>(); |
| } |