| //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "llvm/ADT/SetVector.h" |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| |
| //===----------------------------------------------------------------------===// |
| // Printing. |
| //===----------------------------------------------------------------------===// |
| |
// Forward declaration: printTypeImpl is defined further below, but the struct,
// array/vector and function printers that precede it call it recursively to
// print their nested types.
static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
                          llvm::SetVector<StringRef> &stack);
| |
| /// Returns the keyword to use for the given type. |
| static StringRef getTypeKeyword(LLVMType type) { |
| switch (type.getKind()) { |
| case LLVMType::VoidType: |
| return "void"; |
| case LLVMType::HalfType: |
| return "half"; |
| case LLVMType::BFloatType: |
| return "bfloat"; |
| case LLVMType::FloatType: |
| return "float"; |
| case LLVMType::DoubleType: |
| return "double"; |
| case LLVMType::FP128Type: |
| return "fp128"; |
| case LLVMType::X86FP80Type: |
| return "x86_fp80"; |
| case LLVMType::PPCFP128Type: |
| return "ppc_fp128"; |
| case LLVMType::X86MMXType: |
| return "x86_mmx"; |
| case LLVMType::TokenType: |
| return "token"; |
| case LLVMType::LabelType: |
| return "label"; |
| case LLVMType::MetadataType: |
| return "metadata"; |
| case LLVMType::FunctionType: |
| return "func"; |
| case LLVMType::IntegerType: |
| return "i"; |
| case LLVMType::PointerType: |
| return "ptr"; |
| case LLVMType::FixedVectorType: |
| case LLVMType::ScalableVectorType: |
| return "vec"; |
| case LLVMType::ArrayType: |
| return "array"; |
| case LLVMType::StructType: |
| return "struct"; |
| } |
| llvm_unreachable("unhandled type kind"); |
| } |
| |
| /// Prints the body of a structure type. Uses `stack` to avoid printing |
| /// recursive structs indefinitely. |
| static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type, |
| llvm::SetVector<StringRef> &stack) { |
| if (type.isIdentified() && type.isOpaque()) { |
| os << "opaque"; |
| return; |
| } |
| |
| if (type.isPacked()) |
| os << "packed "; |
| |
| // Put the current type on stack to avoid infinite recursion. |
| os << '('; |
| if (type.isIdentified()) |
| stack.insert(type.getName()); |
| llvm::interleaveComma(type.getBody(), os, [&](LLVMType subtype) { |
| printTypeImpl(os, subtype, stack); |
| }); |
| if (type.isIdentified()) |
| stack.pop_back(); |
| os << ')'; |
| } |
| |
| /// Prints a structure type. Uses `stack` to keep track of the identifiers of |
| /// the structs being printed. Checks if the identifier of a struct is contained |
| /// in `stack`, i.e. whether a self-reference to a recursive stack is being |
| /// printed, and only prints the name to avoid infinite recursion. |
| static void printStructType(llvm::raw_ostream &os, LLVMStructType type, |
| llvm::SetVector<StringRef> &stack) { |
| os << "<"; |
| if (type.isIdentified()) { |
| os << '"' << type.getName() << '"'; |
| // If we are printing a reference to one of the enclosing structs, just |
| // print the name and stop to avoid infinitely long output. |
| if (stack.count(type.getName())) { |
| os << '>'; |
| return; |
| } |
| os << ", "; |
| } |
| |
| printStructTypeBody(os, type, stack); |
| os << '>'; |
| } |
| |
| /// Prints a type containing a fixed number of elements. |
| template <typename TypeTy> |
| static void printArrayOrVectorType(llvm::raw_ostream &os, TypeTy type, |
| llvm::SetVector<StringRef> &stack) { |
| os << '<' << type.getNumElements() << " x "; |
| printTypeImpl(os, type.getElementType(), stack); |
| os << '>'; |
| } |
| |
| /// Prints a function type. |
| static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType, |
| llvm::SetVector<StringRef> &stack) { |
| os << '<'; |
| printTypeImpl(os, funcType.getReturnType(), stack); |
| os << " ("; |
| llvm::interleaveComma( |
| funcType.getParams(), os, |
| [&os, &stack](LLVMType subtype) { printTypeImpl(os, subtype, stack); }); |
| if (funcType.isVarArg()) { |
| if (funcType.getNumParams() != 0) |
| os << ", "; |
| os << "..."; |
| } |
| os << ")>"; |
| } |
| |
| /// Prints the given LLVM dialect type recursively. This leverages closedness of |
| /// the LLVM dialect type system to avoid printing the dialect prefix |
| /// repeatedly. For recursive structures, only prints the name of the structure |
| /// when printing a self-reference. Note that this does not apply to sibling |
| /// references. For example, |
| /// struct<"a", (ptr<struct<"a">>)> |
| /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>, |
| /// ptr<struct<"b", (ptr<struct<"c">>)>>)> |
| /// note that "b" is printed twice. |
| static void printTypeImpl(llvm::raw_ostream &os, LLVMType type, |
| llvm::SetVector<StringRef> &stack) { |
| if (!type) { |
| os << "<<NULL-TYPE>>"; |
| return; |
| } |
| |
| unsigned kind = type.getKind(); |
| os << getTypeKeyword(type); |
| |
| // Trivial types only consist of their keyword. |
| if (LLVMType::FIRST_TRIVIAL_TYPE <= kind && |
| kind <= LLVMType::LAST_TRIVIAL_TYPE) |
| return; |
| |
| if (auto intType = type.dyn_cast<LLVMIntegerType>()) { |
| os << intType.getBitWidth(); |
| return; |
| } |
| |
| if (auto ptrType = type.dyn_cast<LLVMPointerType>()) { |
| os << '<'; |
| printTypeImpl(os, ptrType.getElementType(), stack); |
| if (ptrType.getAddressSpace() != 0) |
| os << ", " << ptrType.getAddressSpace(); |
| os << '>'; |
| return; |
| } |
| |
| if (auto arrayType = type.dyn_cast<LLVMArrayType>()) |
| return printArrayOrVectorType(os, arrayType, stack); |
| if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>()) |
| return printArrayOrVectorType(os, vectorType, stack); |
| |
| if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) { |
| os << "<? x " << vectorType.getMinNumElements() << " x "; |
| printTypeImpl(os, vectorType.getElementType(), stack); |
| os << '>'; |
| return; |
| } |
| |
| if (auto structType = type.dyn_cast<LLVMStructType>()) |
| return printStructType(os, structType, stack); |
| |
| printFunctionType(os, type.cast<LLVMFunctionType>(), stack); |
| } |
| |
| void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) { |
| llvm::SetVector<StringRef> stack; |
| return printTypeImpl(printer.getStream(), type, stack); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Parsing. |
| //===----------------------------------------------------------------------===// |
| |
// Forward declaration: parseTypeImpl is defined further below, but the parsers
// for parametric types that precede it call back into it to parse their nested
// types.
static LLVMType parseTypeImpl(DialectAsmParser &parser,
                              llvm::SetVector<StringRef> &stack);
| |
| /// Helper to be chained with other parsing functions. |
| static ParseResult parseTypeImpl(DialectAsmParser &parser, |
| llvm::SetVector<StringRef> &stack, |
| LLVMType &result) { |
| result = parseTypeImpl(parser, stack); |
| return success(result != nullptr); |
| } |
| |
| /// Parses an LLVM dialect function type. |
| /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>` |
| static LLVMFunctionType parseFunctionType(DialectAsmParser &parser, |
| llvm::SetVector<StringRef> &stack) { |
| Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); |
| LLVMType returnType; |
| if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) || |
| parser.parseLParen()) |
| return LLVMFunctionType(); |
| |
| // Function type without arguments. |
| if (succeeded(parser.parseOptionalRParen())) { |
| if (succeeded(parser.parseGreater())) |
| return LLVMFunctionType::getChecked(loc, returnType, {}, |
| /*isVarArg=*/false); |
| return LLVMFunctionType(); |
| } |
| |
| // Parse arguments. |
| SmallVector<LLVMType, 8> argTypes; |
| do { |
| if (succeeded(parser.parseOptionalEllipsis())) { |
| if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) |
| return LLVMFunctionType(); |
| return LLVMFunctionType::getChecked(loc, returnType, argTypes, |
| /*isVarArg=*/true); |
| } |
| |
| argTypes.push_back(parseTypeImpl(parser, stack)); |
| if (!argTypes.back()) |
| return LLVMFunctionType(); |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) |
| return LLVMFunctionType(); |
| return LLVMFunctionType::getChecked(loc, returnType, argTypes, |
| /*isVarArg=*/false); |
| } |
| |
| /// Parses an LLVM dialect pointer type. |
| /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` |
| static LLVMPointerType parsePointerType(DialectAsmParser &parser, |
| llvm::SetVector<StringRef> &stack) { |
| Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); |
| LLVMType elementType; |
| if (parser.parseLess() || parseTypeImpl(parser, stack, elementType)) |
| return LLVMPointerType(); |
| |
| unsigned addressSpace = 0; |
| if (succeeded(parser.parseOptionalComma()) && |
| failed(parser.parseInteger(addressSpace))) |
| return LLVMPointerType(); |
| if (failed(parser.parseGreater())) |
| return LLVMPointerType(); |
| return LLVMPointerType::getChecked(loc, elementType, addressSpace); |
| } |
| |
| /// Parses an LLVM dialect vector type. |
| /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>` |
| /// Supports both fixed and scalable vectors. |
| static LLVMVectorType parseVectorType(DialectAsmParser &parser, |
| llvm::SetVector<StringRef> &stack) { |
| SmallVector<int64_t, 2> dims; |
| llvm::SMLoc dimPos; |
| LLVMType elementType; |
| Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); |
| if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || |
| parser.parseDimensionList(dims, /*allowDynamic=*/true) || |
| parseTypeImpl(parser, stack, elementType) || parser.parseGreater()) |
| return LLVMVectorType(); |
| |
| // We parsed a generic dimension list, but vectors only support two forms: |
| // - single non-dynamic entry in the list (fixed vector); |
| // - two elements, the first dynamic (indicated by -1) and the second |
| // non-dynamic (scalable vector). |
| if (dims.empty() || dims.size() > 2 || |
| ((dims.size() == 2) ^ (dims[0] == -1)) || |
| (dims.size() == 2 && dims[1] == -1)) { |
| parser.emitError(dimPos) |
| << "expected '? x <integer> x <type>' or '<integer> x <type>'"; |
| return LLVMVectorType(); |
| } |
| |
| bool isScalable = dims.size() == 2; |
| if (isScalable) |
| return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]); |
| return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]); |
| } |
| |
| /// Parses an LLVM dialect array type. |
| /// llvm-type ::= `array<` integer `x` llvm-type `>` |
| static LLVMArrayType parseArrayType(DialectAsmParser &parser, |
| llvm::SetVector<StringRef> &stack) { |
| SmallVector<int64_t, 1> dims; |
| llvm::SMLoc sizePos; |
| LLVMType elementType; |
| Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); |
| if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || |
| parser.parseDimensionList(dims, /*allowDynamic=*/false) || |
| parseTypeImpl(parser, stack, elementType) || parser.parseGreater()) |
| return LLVMArrayType(); |
| |
| if (dims.size() != 1) { |
| parser.emitError(sizePos) << "expected ? x <type>"; |
| return LLVMArrayType(); |
| } |
| |
| return LLVMArrayType::getChecked(loc, elementType, dims[0]); |
| } |
| |
| /// Attempts to set the body of an identified structure type. Reports a parsing |
| /// error at `subtypesLoc` in case of failure, uses `stack` to make sure the |
| /// types printed in the error message look like they did when parsed. |
| static LLVMStructType trySetStructBody(LLVMStructType type, |
| ArrayRef<LLVMType> subtypes, |
| bool isPacked, DialectAsmParser &parser, |
| llvm::SMLoc subtypesLoc, |
| llvm::SetVector<StringRef> &stack) { |
| for (LLVMType t : subtypes) { |
| if (!LLVMStructType::isValidElementType(t)) { |
| parser.emitError(subtypesLoc) |
| << "invalid LLVM structure element type: " << t; |
| return LLVMStructType(); |
| } |
| } |
| |
| if (succeeded(type.setBody(subtypes, isPacked))) |
| return type; |
| |
| std::string currentBody; |
| llvm::raw_string_ostream currentBodyStream(currentBody); |
| printStructTypeBody(currentBodyStream, type, stack); |
| auto diag = parser.emitError(subtypesLoc) |
| << "identified type already used with a different body"; |
| diag.attachNote() << "existing body: " << currentBodyStream.str(); |
| return LLVMStructType(); |
| } |
| |
| /// Parses an LLVM dialect structure type. |
| /// llvm-type ::= `struct<` (string-literal `,`)? `packed`? |
| /// `(` llvm-type-list `)` `>` |
| /// | `struct<` string-literal `>` |
| /// | `struct<` string-literal `, opaque>` |
| static LLVMStructType parseStructType(DialectAsmParser &parser, |
| llvm::SetVector<StringRef> &stack) { |
| Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); |
| |
| if (failed(parser.parseLess())) |
| return LLVMStructType(); |
| |
| // If we are parsing a self-reference to a recursive struct, i.e. the parsing |
| // stack already contains a struct with the same identifier, bail out after |
| // the name. |
| StringRef name; |
| bool isIdentified = succeeded(parser.parseOptionalString(&name)); |
| if (isIdentified) { |
| if (stack.count(name)) { |
| if (failed(parser.parseGreater())) |
| return LLVMStructType(); |
| return LLVMStructType::getIdentifiedChecked(loc, name); |
| } |
| if (failed(parser.parseComma())) |
| return LLVMStructType(); |
| } |
| |
| // Handle intentionally opaque structs. |
| llvm::SMLoc kwLoc = parser.getCurrentLocation(); |
| if (succeeded(parser.parseOptionalKeyword("opaque"))) { |
| if (!isIdentified) |
| return parser.emitError(kwLoc, "only identified structs can be opaque"), |
| LLVMStructType(); |
| if (failed(parser.parseGreater())) |
| return LLVMStructType(); |
| auto type = LLVMStructType::getOpaqueChecked(loc, name); |
| if (!type.isOpaque()) { |
| parser.emitError(kwLoc, "redeclaring defined struct as opaque"); |
| return LLVMStructType(); |
| } |
| return type; |
| } |
| |
| // Check for packedness. |
| bool isPacked = succeeded(parser.parseOptionalKeyword("packed")); |
| if (failed(parser.parseLParen())) |
| return LLVMStructType(); |
| |
| // Fast pass for structs with zero subtypes. |
| if (succeeded(parser.parseOptionalRParen())) { |
| if (failed(parser.parseGreater())) |
| return LLVMStructType(); |
| if (!isIdentified) |
| return LLVMStructType::getLiteralChecked(loc, {}, isPacked); |
| auto type = LLVMStructType::getIdentifiedChecked(loc, name); |
| return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack); |
| } |
| |
| // Parse subtypes. For identified structs, put the identifier of the struct on |
| // the stack to support self-references in the recursive calls. |
| SmallVector<LLVMType, 4> subtypes; |
| llvm::SMLoc subtypesLoc = parser.getCurrentLocation(); |
| do { |
| if (isIdentified) |
| stack.insert(name); |
| LLVMType type = parseTypeImpl(parser, stack); |
| if (!type) |
| return LLVMStructType(); |
| subtypes.push_back(type); |
| if (isIdentified) |
| stack.pop_back(); |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (parser.parseRParen() || parser.parseGreater()) |
| return LLVMStructType(); |
| |
| // Construct the struct with body. |
| if (!isIdentified) |
| return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked); |
| auto type = LLVMStructType::getIdentifiedChecked(loc, name); |
| return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack); |
| } |
| |
| /// Parses one of the LLVM dialect types. |
| static LLVMType parseTypeImpl(DialectAsmParser &parser, |
| llvm::SetVector<StringRef> &stack) { |
| // Special case for integers (i[1-9][0-9]*) that are literals rather than |
| // keywords for the parser, so they are not caught by the main dispatch below. |
| // Try parsing it a built-in integer type instead. |
| Type maybeIntegerType; |
| MLIRContext *ctx = parser.getBuilder().getContext(); |
| llvm::SMLoc keyLoc = parser.getCurrentLocation(); |
| Location loc = parser.getEncodedSourceLoc(keyLoc); |
| OptionalParseResult result = parser.parseOptionalType(maybeIntegerType); |
| if (result.hasValue()) { |
| if (failed(*result)) |
| return LLVMType(); |
| |
| if (!maybeIntegerType.isSignlessInteger()) { |
| parser.emitError(keyLoc) << "unexpected type, expected i* or keyword"; |
| return LLVMType(); |
| } |
| return LLVMIntegerType::getChecked( |
| loc, maybeIntegerType.getIntOrFloatBitWidth()); |
| } |
| |
| // Dispatch to concrete functions. |
| StringRef key; |
| if (failed(parser.parseKeyword(&key))) |
| return LLVMType(); |
| |
| return llvm::StringSwitch<function_ref<LLVMType()>>(key) |
| .Case("void", [&] { return LLVMVoidType::get(ctx); }) |
| .Case("half", [&] { return LLVMHalfType::get(ctx); }) |
| .Case("bfloat", [&] { return LLVMBFloatType::get(ctx); }) |
| .Case("float", [&] { return LLVMFloatType::get(ctx); }) |
| .Case("double", [&] { return LLVMDoubleType::get(ctx); }) |
| .Case("fp128", [&] { return LLVMFP128Type::get(ctx); }) |
| .Case("x86_fp80", [&] { return LLVMX86FP80Type::get(ctx); }) |
| .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); }) |
| .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); }) |
| .Case("token", [&] { return LLVMTokenType::get(ctx); }) |
| .Case("label", [&] { return LLVMLabelType::get(ctx); }) |
| .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) |
| .Case("func", [&] { return parseFunctionType(parser, stack); }) |
| .Case("ptr", [&] { return parsePointerType(parser, stack); }) |
| .Case("vec", [&] { return parseVectorType(parser, stack); }) |
| .Case("array", [&] { return parseArrayType(parser, stack); }) |
| .Case("struct", [&] { return parseStructType(parser, stack); }) |
| .Default([&] { |
| parser.emitError(keyLoc) << "unknown LLVM type: " << key; |
| return LLVMType(); |
| })(); |
| } |
| |
| LLVMType mlir::LLVM::detail::parseType(DialectAsmParser &parser) { |
| llvm::SetVector<StringRef> stack; |
| return parseTypeImpl(parser, stack); |
| } |