| //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the parser for the MLIR textual form. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Parser.h" |
| #include "Lexer.h" |
| #include "mlir/Analysis/Verifier.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/IntegerSet.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/Support/STLExtras.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/ADT/bit.h" |
| #include "llvm/Support/PrettyStackTrace.h" |
| #include "llvm/Support/SMLoc.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include <algorithm> |
| using namespace mlir; |
| using llvm::MemoryBuffer; |
| using llvm::SMLoc; |
| using llvm::SourceMgr; |
| |
| namespace { |
| class Parser; |
| |
| //===----------------------------------------------------------------------===// |
| // SymbolState |
| //===----------------------------------------------------------------------===// |
| |
| /// This class contains record of any parsed top-level symbols. |
| struct SymbolState { |
| // A map from attribute alias identifier to Attribute. |
| llvm::StringMap<Attribute> attributeAliasDefinitions; |
| |
| // A map from type alias identifier to Type. |
| llvm::StringMap<Type> typeAliasDefinitions; |
| |
| /// A set of locations into the main parser memory buffer for each of the |
| /// active nested parsers. Given that some nested parsers, i.e. custom dialect |
| /// parsers, operate on a temporary memory buffer, this provides an anchor |
| /// point for emitting diagnostics. |
| SmallVector<llvm::SMLoc, 1> nestedParserLocs; |
| |
| /// The top-level lexer that contains the original memory buffer provided by |
| /// the user. This is used by nested parsers to get a properly encoded source |
| /// location. |
| Lexer *topLevelLexer = nullptr; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ParserState |
| //===----------------------------------------------------------------------===// |
| |
| /// This class refers to all of the state maintained globally by the parser, |
| /// such as the current lexer position etc. |
| struct ParserState { |
| ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, |
| SymbolState &symbols) |
| : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), |
| symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { |
| // Set the top level lexer for the symbol state if one doesn't exist. |
| if (!symbols.topLevelLexer) |
| symbols.topLevelLexer = &lex; |
| } |
| ~ParserState() { |
| // Reset the top level lexer if it refers the lexer in our state. |
| if (symbols.topLevelLexer == &lex) |
| symbols.topLevelLexer = nullptr; |
| } |
| ParserState(const ParserState &) = delete; |
| void operator=(const ParserState &) = delete; |
| |
| /// The context we're parsing into. |
| MLIRContext *const context; |
| |
| /// The lexer for the source file we're parsing. |
| Lexer lex; |
| |
| /// This is the next token that hasn't been consumed yet. |
| Token curToken; |
| |
| /// The current state for symbol parsing. |
| SymbolState &symbols; |
| |
| /// The depth of this parser in the nested parsing stack. |
| size_t parserDepth; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Parser |
| //===----------------------------------------------------------------------===// |
| |
| /// This class implement support for parsing global entities like types and |
| /// shared entities like SSA names. It is intended to be subclassed by |
| /// specialized subparsers that include state, e.g. when a local symbol table. |
| class Parser { |
| public: |
| Builder builder; |
| |
| Parser(ParserState &state) : builder(state.context), state(state) {} |
| |
| // Helper methods to get stuff from the parser-global state. |
| ParserState &getState() const { return state; } |
| MLIRContext *getContext() const { return state.context; } |
| const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } |
| |
| /// Parse a comma-separated list of elements up until the specified end token. |
| ParseResult |
| parseCommaSeparatedListUntil(Token::Kind rightToken, |
| const std::function<ParseResult()> &parseElement, |
| bool allowEmptyList = true); |
| |
| /// Parse a comma separated list of elements that must have at least one entry |
| /// in it. |
| ParseResult |
| parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); |
| |
| ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); |
| |
| // We have two forms of parsing methods - those that return a non-null |
| // pointer on success, and those that return a ParseResult to indicate whether |
| // they returned a failure. The second class fills in by-reference arguments |
| // as the results of their action. |
| |
| //===--------------------------------------------------------------------===// |
| // Error Handling |
| //===--------------------------------------------------------------------===// |
| |
| /// Emit an error and return failure. |
| InFlightDiagnostic emitError(const Twine &message = {}) { |
| return emitError(state.curToken.getLoc(), message); |
| } |
| InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); |
| |
| /// Encode the specified source location information into an attribute for |
| /// attachment to the IR. |
| Location getEncodedSourceLocation(llvm::SMLoc loc) { |
| // If there are no active nested parsers, we can get the encoded source |
| // location directly. |
| if (state.parserDepth == 0) |
| return state.lex.getEncodedSourceLocation(loc); |
| // Otherwise, we need to re-encode it to point to the top level buffer. |
| return state.symbols.topLevelLexer->getEncodedSourceLocation( |
| remapLocationToTopLevelBuffer(loc)); |
| } |
| |
| /// Remaps the given SMLoc to the top level lexer of the parser. This is used |
| /// to adjust locations of potentially nested parsers to ensure that they can |
| /// be emitted properly as diagnostics. |
| llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { |
| // If there are no active nested parsers, we can return location directly. |
| SymbolState &symbols = state.symbols; |
| if (state.parserDepth == 0) |
| return loc; |
| assert(symbols.topLevelLexer && "expected valid top-level lexer"); |
| |
| // Otherwise, we need to remap the location to the main parser. This is |
| // simply offseting the location onto the location of the last nested |
| // parser. |
| size_t offset = loc.getPointer() - state.lex.getBufferBegin(); |
| auto *rawLoc = |
| symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; |
| return llvm::SMLoc::getFromPointer(rawLoc); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Token Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Return the current token the parser is inspecting. |
| const Token &getToken() const { return state.curToken; } |
| StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } |
| |
| /// If the current token has the specified kind, consume it and return true. |
| /// If not, return false. |
| bool consumeIf(Token::Kind kind) { |
| if (state.curToken.isNot(kind)) |
| return false; |
| consumeToken(kind); |
| return true; |
| } |
| |
| /// Advance the current lexer onto the next token. |
| void consumeToken() { |
| assert(state.curToken.isNot(Token::eof, Token::error) && |
| "shouldn't advance past EOF or errors"); |
| state.curToken = state.lex.lexToken(); |
| } |
| |
| /// Advance the current lexer onto the next token, asserting what the expected |
| /// current token is. This is preferred to the above method because it leads |
| /// to more self-documenting code with better checking. |
| void consumeToken(Token::Kind kind) { |
| assert(state.curToken.is(kind) && "consumed an unexpected token"); |
| consumeToken(); |
| } |
| |
| /// Consume the specified token if present and return success. On failure, |
| /// output a diagnostic and return failure. |
| ParseResult parseToken(Token::Kind expectedToken, const Twine &message); |
| |
| //===--------------------------------------------------------------------===// |
| // Type Parsing |
| //===--------------------------------------------------------------------===// |
| |
| ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); |
| ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); |
| ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); |
| |
| /// Parse an arbitrary type. |
| Type parseType(); |
| |
| /// Parse a complex type. |
| Type parseComplexType(); |
| |
| /// Parse an extended type. |
| Type parseExtendedType(); |
| |
| /// Parse a function type. |
| Type parseFunctionType(); |
| |
| /// Parse a memref type. |
| Type parseMemRefType(); |
| |
| /// Parse a non function type. |
| Type parseNonFunctionType(); |
| |
| /// Parse a tensor type. |
| Type parseTensorType(); |
| |
| /// Parse a tuple type. |
| Type parseTupleType(); |
| |
| /// Parse a vector type. |
| VectorType parseVectorType(); |
| ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, |
| bool allowDynamic = true); |
| ParseResult parseXInDimensionList(); |
| |
| /// Parse strided layout specification. |
| ParseResult parseStridedLayout(int64_t &offset, |
| SmallVectorImpl<int64_t> &strides); |
| |
| // Parse a brace-delimiter list of comma-separated integers with `?` as an |
| // unknown marker. |
| ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions); |
| |
| //===--------------------------------------------------------------------===// |
| // Attribute Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse an arbitrary attribute with an optional type. |
| Attribute parseAttribute(Type type = {}); |
| |
| /// Parse an attribute dictionary. |
| ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); |
| |
| /// Parse an extended attribute. |
| Attribute parseExtendedAttr(Type type); |
| |
| /// Parse a float attribute. |
| Attribute parseFloatAttr(Type type, bool isNegative); |
| |
| /// Parse a decimal or a hexadecimal literal, which can be either an integer |
| /// or a float attribute. |
| Attribute parseDecOrHexAttr(Type type, bool isNegative); |
| |
| /// Parse an opaque elements attribute. |
| Attribute parseOpaqueElementsAttr(Type attrType); |
| |
| /// Parse a dense elements attribute. |
| Attribute parseDenseElementsAttr(Type attrType); |
| ShapedType parseElementsLiteralType(Type type); |
| |
| /// Parse a sparse elements attribute. |
| Attribute parseSparseElementsAttr(Type attrType); |
| |
| //===--------------------------------------------------------------------===// |
| // Location Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse an inline location. |
| ParseResult parseLocation(LocationAttr &loc); |
| |
| /// Parse a raw location instance. |
| ParseResult parseLocationInstance(LocationAttr &loc); |
| |
| /// Parse a callsite location instance. |
| ParseResult parseCallSiteLocation(LocationAttr &loc); |
| |
| /// Parse a fused location instance. |
| ParseResult parseFusedLocation(LocationAttr &loc); |
| |
| /// Parse a name or FileLineCol location instance. |
| ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); |
| |
| /// Parse an optional trailing location. |
| /// |
| /// trailing-location ::= (`loc` `(` location `)`)? |
| /// |
| ParseResult parseOptionalTrailingLocation(Location &loc) { |
| // If there is a 'loc' we parse a trailing location. |
| if (!getToken().is(Token::kw_loc)) |
| return success(); |
| |
| // Parse the location. |
| LocationAttr directLoc; |
| if (parseLocation(directLoc)) |
| return failure(); |
| loc = directLoc; |
| return success(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Affine Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a reference to either an affine map, or an integer set. |
| ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, |
| IntegerSet &set); |
| ParseResult parseAffineMapReference(AffineMap &map); |
| ParseResult parseIntegerSetReference(IntegerSet &set); |
| |
| /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. |
| ParseResult |
| parseAffineMapOfSSAIds(AffineMap &map, |
| function_ref<ParseResult(bool)> parseElement, |
| OpAsmParser::Delimiter delimiter); |
| |
| private: |
| /// The Parser is subclassed and reinstantiated. Do not add additional |
| /// non-trivial state here, add it to the ParserState class. |
| ParserState &state; |
| }; |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Helper methods. |
| //===----------------------------------------------------------------------===// |
| |
| /// Parse a comma separated list of elements that must have at least one entry |
| /// in it. |
| ParseResult Parser::parseCommaSeparatedList( |
| const std::function<ParseResult()> &parseElement) { |
| // Non-empty case starts with an element. |
| if (parseElement()) |
| return failure(); |
| |
| // Otherwise we have a list of comma separated elements. |
| while (consumeIf(Token::comma)) { |
| if (parseElement()) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| /// Parse a comma-separated list of elements, terminated with an arbitrary |
| /// token. This allows empty lists if allowEmptyList is true. |
| /// |
| /// abstract-list ::= rightToken // if allowEmptyList == true |
| /// abstract-list ::= element (',' element)* rightToken |
| /// |
| ParseResult Parser::parseCommaSeparatedListUntil( |
| Token::Kind rightToken, const std::function<ParseResult()> &parseElement, |
| bool allowEmptyList) { |
| // Handle the empty case. |
| if (getToken().is(rightToken)) { |
| if (!allowEmptyList) |
| return emitError("expected list element"); |
| consumeToken(rightToken); |
| return success(); |
| } |
| |
| if (parseCommaSeparatedList(parseElement) || |
| parseToken(rightToken, "expected ',' or '" + |
| Token::getTokenSpelling(rightToken) + "'")) |
| return failure(); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DialectAsmParser |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class provides the main implementation of the DialectAsmParser that |
| /// allows for dialects to parse attributes and types. This allows for dialect |
| /// hooking into the main MLIR parsing logic. |
| class CustomDialectAsmParser : public DialectAsmParser { |
| public: |
| CustomDialectAsmParser(StringRef fullSpec, Parser &parser) |
| : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), |
| parser(parser) {} |
| ~CustomDialectAsmParser() override {} |
| |
| /// Emit a diagnostic at the specified location and return failure. |
| InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { |
| return parser.emitError(loc, message); |
| } |
| |
| /// Return a builder which provides useful access to MLIRContext, global |
| /// objects like types and attributes. |
| Builder &getBuilder() const override { return parser.builder; } |
| |
| /// Get the location of the next token and store it into the argument. This |
| /// always succeeds. |
| llvm::SMLoc getCurrentLocation() override { |
| return parser.getToken().getLoc(); |
| } |
| |
| /// Return the location of the original name token. |
| llvm::SMLoc getNameLoc() const override { return nameLoc; } |
| |
| /// Re-encode the given source location as an MLIR location and return it. |
| Location getEncodedSourceLoc(llvm::SMLoc loc) override { |
| return parser.getEncodedSourceLocation(loc); |
| } |
| |
| /// Returns the full specification of the symbol being parsed. This allows |
| /// for using a separate parser if necessary. |
| StringRef getFullSymbolSpec() const override { return fullSpec; } |
| |
| /// Parse a floating point value from the stream. |
| ParseResult parseFloat(double &result) override { |
| bool negative = parser.consumeIf(Token::minus); |
| Token curTok = parser.getToken(); |
| |
| // Check for a floating point value. |
| if (curTok.is(Token::floatliteral)) { |
| auto val = curTok.getFloatingPointValue(); |
| if (!val.hasValue()) |
| return emitError(curTok.getLoc(), "floating point value too large"); |
| parser.consumeToken(Token::floatliteral); |
| result = negative ? -*val : *val; |
| return success(); |
| } |
| |
| // TODO(riverriddle) support hex floating point values. |
| return emitError(getCurrentLocation(), "expected floating point literal"); |
| } |
| |
| /// Parse an optional integer value from the stream. |
| OptionalParseResult parseOptionalInteger(uint64_t &result) override { |
| Token curToken = parser.getToken(); |
| if (curToken.isNot(Token::integer, Token::minus)) |
| return llvm::None; |
| |
| bool negative = parser.consumeIf(Token::minus); |
| Token curTok = parser.getToken(); |
| if (parser.parseToken(Token::integer, "expected integer value")) |
| return failure(); |
| |
| auto val = curTok.getUInt64IntegerValue(); |
| if (!val) |
| return emitError(curTok.getLoc(), "integer value too large"); |
| result = negative ? -*val : *val; |
| return success(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Token Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a `->` token. |
| ParseResult parseArrow() override { |
| return parser.parseToken(Token::arrow, "expected '->'"); |
| } |
| |
| /// Parses a `->` if present. |
| ParseResult parseOptionalArrow() override { |
| return success(parser.consumeIf(Token::arrow)); |
| } |
| |
| /// Parse a '{' token. |
| ParseResult parseLBrace() override { |
| return parser.parseToken(Token::l_brace, "expected '{'"); |
| } |
| |
| /// Parse a '{' token if present |
| ParseResult parseOptionalLBrace() override { |
| return success(parser.consumeIf(Token::l_brace)); |
| } |
| |
| /// Parse a `}` token. |
| ParseResult parseRBrace() override { |
| return parser.parseToken(Token::r_brace, "expected '}'"); |
| } |
| |
| /// Parse a `}` token if present |
| ParseResult parseOptionalRBrace() override { |
| return success(parser.consumeIf(Token::r_brace)); |
| } |
| |
| /// Parse a `:` token. |
| ParseResult parseColon() override { |
| return parser.parseToken(Token::colon, "expected ':'"); |
| } |
| |
| /// Parse a `:` token if present. |
| ParseResult parseOptionalColon() override { |
| return success(parser.consumeIf(Token::colon)); |
| } |
| |
| /// Parse a `,` token. |
| ParseResult parseComma() override { |
| return parser.parseToken(Token::comma, "expected ','"); |
| } |
| |
| /// Parse a `,` token if present. |
| ParseResult parseOptionalComma() override { |
| return success(parser.consumeIf(Token::comma)); |
| } |
| |
| /// Parses a `...` if present. |
| ParseResult parseOptionalEllipsis() override { |
| return success(parser.consumeIf(Token::ellipsis)); |
| } |
| |
| /// Parse a `=` token. |
| ParseResult parseEqual() override { |
| return parser.parseToken(Token::equal, "expected '='"); |
| } |
| |
| /// Parse a '<' token. |
| ParseResult parseLess() override { |
| return parser.parseToken(Token::less, "expected '<'"); |
| } |
| |
| /// Parse a `<` token if present. |
| ParseResult parseOptionalLess() override { |
| return success(parser.consumeIf(Token::less)); |
| } |
| |
| /// Parse a '>' token. |
| ParseResult parseGreater() override { |
| return parser.parseToken(Token::greater, "expected '>'"); |
| } |
| |
| /// Parse a `>` token if present. |
| ParseResult parseOptionalGreater() override { |
| return success(parser.consumeIf(Token::greater)); |
| } |
| |
| /// Parse a `(` token. |
| ParseResult parseLParen() override { |
| return parser.parseToken(Token::l_paren, "expected '('"); |
| } |
| |
| /// Parses a '(' if present. |
| ParseResult parseOptionalLParen() override { |
| return success(parser.consumeIf(Token::l_paren)); |
| } |
| |
| /// Parse a `)` token. |
| ParseResult parseRParen() override { |
| return parser.parseToken(Token::r_paren, "expected ')'"); |
| } |
| |
| /// Parses a ')' if present. |
| ParseResult parseOptionalRParen() override { |
| return success(parser.consumeIf(Token::r_paren)); |
| } |
| |
| /// Parse a `[` token. |
| ParseResult parseLSquare() override { |
| return parser.parseToken(Token::l_square, "expected '['"); |
| } |
| |
| /// Parses a '[' if present. |
| ParseResult parseOptionalLSquare() override { |
| return success(parser.consumeIf(Token::l_square)); |
| } |
| |
| /// Parse a `]` token. |
| ParseResult parseRSquare() override { |
| return parser.parseToken(Token::r_square, "expected ']'"); |
| } |
| |
| /// Parses a ']' if present. |
| ParseResult parseOptionalRSquare() override { |
| return success(parser.consumeIf(Token::r_square)); |
| } |
| |
| /// Parses a '?' if present. |
| ParseResult parseOptionalQuestion() override { |
| return success(parser.consumeIf(Token::question)); |
| } |
| |
| /// Parses a '*' if present. |
| ParseResult parseOptionalStar() override { |
| return success(parser.consumeIf(Token::star)); |
| } |
| |
| /// Returns if the current token corresponds to a keyword. |
| bool isCurrentTokenAKeyword() const { |
| return parser.getToken().is(Token::bare_identifier) || |
| parser.getToken().isKeyword(); |
| } |
| |
| /// Parse the given keyword if present. |
| ParseResult parseOptionalKeyword(StringRef keyword) override { |
| // Check that the current token has the same spelling. |
| if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) |
| return failure(); |
| parser.consumeToken(); |
| return success(); |
| } |
| |
| /// Parse a keyword, if present, into 'keyword'. |
| ParseResult parseOptionalKeyword(StringRef *keyword) override { |
| // Check that the current token is a keyword. |
| if (!isCurrentTokenAKeyword()) |
| return failure(); |
| |
| *keyword = parser.getTokenSpelling(); |
| parser.consumeToken(); |
| return success(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Attribute Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse an arbitrary attribute and return it in result. |
| ParseResult parseAttribute(Attribute &result, Type type) override { |
| result = parser.parseAttribute(type); |
| return success(static_cast<bool>(result)); |
| } |
| |
| /// Parse an affine map instance into 'map'. |
| ParseResult parseAffineMap(AffineMap &map) override { |
| return parser.parseAffineMapReference(map); |
| } |
| |
| /// Parse an integer set instance into 'set'. |
| ParseResult printIntegerSet(IntegerSet &set) override { |
| return parser.parseIntegerSetReference(set); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Type Parsing |
| //===--------------------------------------------------------------------===// |
| |
| ParseResult parseType(Type &result) override { |
| result = parser.parseType(); |
| return success(static_cast<bool>(result)); |
| } |
| |
| ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
| bool allowDynamic) override { |
| return parser.parseDimensionListRanked(dimensions, allowDynamic); |
| } |
| |
| private: |
| /// The full symbol specification. |
| StringRef fullSpec; |
| |
| /// The source location of the dialect symbol. |
| SMLoc nameLoc; |
| |
| /// The main parser. |
| Parser &parser; |
| }; |
| } // namespace |
| |
| /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, |
| /// and may be recursive. Return with the 'prettyName' StringRef encompassing |
| /// the entire pretty name. |
| /// |
| /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' |
| /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body |
| /// | '(' pretty-dialect-sym-contents+ ')' |
| /// | '[' pretty-dialect-sym-contents+ ']' |
| /// | '{' pretty-dialect-sym-contents+ '}' |
| /// | '[^[<({>\])}\0]+' |
| /// |
| ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { |
| // Pretty symbol names are a relatively unstructured format that contains a |
| // series of properly nested punctuation, with anything else in the middle. |
| // Scan ahead to find it and consume it if successful, otherwise emit an |
| // error. |
| auto *curPtr = getTokenSpelling().data(); |
| |
| SmallVector<char, 8> nestedPunctuation; |
| |
| // Scan over the nested punctuation, bailing out on error and consuming until |
| // we find the end. We know that we're currently looking at the '<', so we |
| // can go until we find the matching '>' character. |
| assert(*curPtr == '<'); |
| do { |
| char c = *curPtr++; |
| switch (c) { |
| case '\0': |
| // This also handles the EOF case. |
| return emitError("unexpected nul or EOF in pretty dialect name"); |
| case '<': |
| case '[': |
| case '(': |
| case '{': |
| nestedPunctuation.push_back(c); |
| continue; |
| |
| case '-': |
| // The sequence `->` is treated as special token. |
| if (*curPtr == '>') |
| ++curPtr; |
| continue; |
| |
| case '>': |
| if (nestedPunctuation.pop_back_val() != '<') |
| return emitError("unbalanced '>' character in pretty dialect name"); |
| break; |
| case ']': |
| if (nestedPunctuation.pop_back_val() != '[') |
| return emitError("unbalanced ']' character in pretty dialect name"); |
| break; |
| case ')': |
| if (nestedPunctuation.pop_back_val() != '(') |
| return emitError("unbalanced ')' character in pretty dialect name"); |
| break; |
| case '}': |
| if (nestedPunctuation.pop_back_val() != '{') |
| return emitError("unbalanced '}' character in pretty dialect name"); |
| break; |
| |
| default: |
| continue; |
| } |
| } while (!nestedPunctuation.empty()); |
| |
| // Ok, we succeeded, remember where we stopped, reset the lexer to know it is |
| // consuming all this stuff, and return. |
| state.lex.resetPointer(curPtr); |
| |
| unsigned length = curPtr - prettyName.begin(); |
| prettyName = StringRef(prettyName.begin(), length); |
| consumeToken(); |
| return success(); |
| } |
| |
| /// Parse an extended dialect symbol. |
| template <typename Symbol, typename SymbolAliasMap, typename CreateFn> |
| static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, |
| SymbolAliasMap &aliases, |
| CreateFn &&createSymbol) { |
| // Parse the dialect namespace. |
| StringRef identifier = p.getTokenSpelling().drop_front(); |
| auto loc = p.getToken().getLoc(); |
| p.consumeToken(identifierTok); |
| |
| // If there is no '<' token following this, and if the typename contains no |
| // dot, then we are parsing a symbol alias. |
| if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { |
| // Check for an alias for this type. |
| auto aliasIt = aliases.find(identifier); |
| if (aliasIt == aliases.end()) |
| return (p.emitError("undefined symbol alias id '" + identifier + "'"), |
| nullptr); |
| return aliasIt->second; |
| } |
| |
| // Otherwise, we are parsing a dialect-specific symbol. If the name contains |
| // a dot, then this is the "pretty" form. If not, it is the verbose form that |
| // looks like <"...">. |
| std::string symbolData; |
| auto dialectName = identifier; |
| |
| // Handle the verbose form, where "identifier" is a simple dialect name. |
| if (!identifier.contains('.')) { |
| // Consume the '<'. |
| if (p.parseToken(Token::less, "expected '<' in dialect type")) |
| return nullptr; |
| |
| // Parse the symbol specific data. |
| if (p.getToken().isNot(Token::string)) |
| return (p.emitError("expected string literal data in dialect symbol"), |
| nullptr); |
| symbolData = p.getToken().getStringValue(); |
| loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); |
| p.consumeToken(Token::string); |
| |
| // Consume the '>'. |
| if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) |
| return nullptr; |
| } else { |
| // Ok, the dialect name is the part of the identifier before the dot, the |
| // part after the dot is the dialect's symbol, or the start thereof. |
| auto dotHalves = identifier.split('.'); |
| dialectName = dotHalves.first; |
| auto prettyName = dotHalves.second; |
| loc = llvm::SMLoc::getFromPointer(prettyName.data()); |
| |
| // If the dialect's symbol is followed immediately by a <, then lex the body |
| // of it into prettyName. |
| if (p.getToken().is(Token::less) && |
| prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { |
| if (p.parsePrettyDialectSymbolName(prettyName)) |
| return nullptr; |
| } |
| |
| symbolData = prettyName.str(); |
| } |
| |
| // Record the name location of the type remapped to the top level buffer. |
| llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); |
| p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); |
| |
| // Call into the provided symbol construction function. |
| Symbol sym = createSymbol(dialectName, symbolData, loc); |
| |
| // Pop the last parser location. |
| p.getState().symbols.nestedParserLocs.pop_back(); |
| return sym; |
| } |
| |
| /// Parses a symbol, of type 'T', and returns it if parsing was successful. If |
| /// parsing failed, nullptr is returned. The number of bytes read from the input |
| /// string is returned in 'numRead'. |
| template <typename T, typename ParserFn> |
| static T parseSymbol(StringRef inputStr, MLIRContext *context, |
| SymbolState &symbolState, ParserFn &&parserFn, |
| size_t *numRead = nullptr) { |
| SourceMgr sourceMgr; |
| auto memBuffer = MemoryBuffer::getMemBuffer( |
| inputStr, /*BufferName=*/"<mlir_parser_buffer>", |
| /*RequiresNullTerminator=*/false); |
| sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); |
| ParserState state(sourceMgr, context, symbolState); |
| Parser parser(state); |
| |
| Token startTok = parser.getToken(); |
| T symbol = parserFn(parser); |
| if (!symbol) |
| return T(); |
| |
| // If 'numRead' is valid, then provide the number of bytes that were read. |
| Token endTok = parser.getToken(); |
| if (numRead) { |
| *numRead = static_cast<size_t>(endTok.getLoc().getPointer() - |
| startTok.getLoc().getPointer()); |
| |
| // Otherwise, ensure that all of the tokens were parsed. |
| } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { |
| parser.emitError(endTok.getLoc(), "encountered unexpected token"); |
| return T(); |
| } |
| return symbol; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Error Handling |
| //===----------------------------------------------------------------------===// |
| |
| InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { |
| auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); |
| |
| // If we hit a parse error in response to a lexer error, then the lexer |
| // already reported the error. |
| if (getToken().is(Token::error)) |
| diag.abandon(); |
| return diag; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Token Parsing |
| //===----------------------------------------------------------------------===// |
| |
| /// Consume the specified token if present and return success. On failure, |
| /// output a diagnostic and return failure. |
| ParseResult Parser::parseToken(Token::Kind expectedToken, |
| const Twine &message) { |
| if (consumeIf(expectedToken)) |
| return success(); |
| return emitError(message); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Type Parsing |
| //===----------------------------------------------------------------------===// |
| |
| /// Parse an arbitrary type. |
| /// |
| /// type ::= function-type |
| /// | non-function-type |
| /// |
| Type Parser::parseType() { |
| if (getToken().is(Token::l_paren)) |
| return parseFunctionType(); |
| return parseNonFunctionType(); |
| } |
| |
| /// Parse a function result type. |
| /// |
| /// function-result-type ::= type-list-parens |
| /// | non-function-type |
| /// |
| ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { |
| if (getToken().is(Token::l_paren)) |
| return parseTypeListParens(elements); |
| |
| Type t = parseNonFunctionType(); |
| if (!t) |
| return failure(); |
| elements.push_back(t); |
| return success(); |
| } |
| |
| /// Parse a list of types without an enclosing parenthesis. The list must have |
| /// at least one member. |
| /// |
| /// type-list-no-parens ::= type (`,` type)* |
| /// |
| ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { |
| auto parseElt = [&]() -> ParseResult { |
| auto elt = parseType(); |
| elements.push_back(elt); |
| return elt ? success() : failure(); |
| }; |
| |
| return parseCommaSeparatedList(parseElt); |
| } |
| |
| /// Parse a parenthesized list of types. |
| /// |
| /// type-list-parens ::= `(` `)` |
| /// | `(` type-list-no-parens `)` |
| /// |
| ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { |
| if (parseToken(Token::l_paren, "expected '('")) |
| return failure(); |
| |
| // Handle empty lists. |
| if (getToken().is(Token::r_paren)) |
| return consumeToken(), success(); |
| |
| if (parseTypeListNoParens(elements) || |
| parseToken(Token::r_paren, "expected ')'")) |
| return failure(); |
| return success(); |
| } |
| |
| /// Parse a complex type. |
| /// |
| /// complex-type ::= `complex` `<` type `>` |
| /// |
| Type Parser::parseComplexType() { |
| consumeToken(Token::kw_complex); |
| |
| // Parse the '<'. |
| if (parseToken(Token::less, "expected '<' in complex type")) |
| return nullptr; |
| |
| llvm::SMLoc elementTypeLoc = getToken().getLoc(); |
| auto elementType = parseType(); |
| if (!elementType || |
| parseToken(Token::greater, "expected '>' in complex type")) |
| return nullptr; |
| if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) |
| return emitError(elementTypeLoc, "invalid element type for complex"), |
| nullptr; |
| |
| return ComplexType::get(elementType); |
| } |
| |
| /// Parse an extended type. |
| /// |
| /// extended-type ::= (dialect-type | type-alias) |
| /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` |
| /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? |
| /// type-alias ::= `!` alias-name |
| /// |
| Type Parser::parseExtendedType() { |
| return parseExtendedSymbol<Type>( |
| *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, |
| [&](StringRef dialectName, StringRef symbolData, |
| llvm::SMLoc loc) -> Type { |
| // If we found a registered dialect, then ask it to parse the type. |
| if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { |
| return parseSymbol<Type>( |
| symbolData, state.context, state.symbols, [&](Parser &parser) { |
| CustomDialectAsmParser customParser(symbolData, parser); |
| return dialect->parseType(customParser); |
| }); |
| } |
| |
| // Otherwise, form a new opaque type. |
| return OpaqueType::getChecked( |
| Identifier::get(dialectName, state.context), symbolData, |
| state.context, getEncodedSourceLocation(loc)); |
| }); |
| } |
| |
| /// Parse a function type. |
| /// |
| /// function-type ::= type-list-parens `->` function-result-type |
| /// |
| Type Parser::parseFunctionType() { |
| assert(getToken().is(Token::l_paren)); |
| |
| SmallVector<Type, 4> arguments, results; |
| if (parseTypeListParens(arguments) || |
| parseToken(Token::arrow, "expected '->' in function type") || |
| parseFunctionResultTypes(results)) |
| return nullptr; |
| |
| return builder.getFunctionType(arguments, results); |
| } |
| |
| /// Parse the offset and strides from a strided layout specification. |
| /// |
| /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list |
| /// |
| ParseResult Parser::parseStridedLayout(int64_t &offset, |
| SmallVectorImpl<int64_t> &strides) { |
| // Parse offset. |
| consumeToken(Token::kw_offset); |
| if (!consumeIf(Token::colon)) |
| return emitError("expected colon after `offset` keyword"); |
| auto maybeOffset = getToken().getUnsignedIntegerValue(); |
| bool question = getToken().is(Token::question); |
| if (!maybeOffset && !question) |
| return emitError("invalid offset"); |
| offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) |
| : MemRefType::getDynamicStrideOrOffset(); |
| consumeToken(); |
| |
| if (!consumeIf(Token::comma)) |
| return emitError("expected comma after offset value"); |
| |
| // Parse stride list. |
| if (!consumeIf(Token::kw_strides)) |
| return emitError("expected `strides` keyword after offset specification"); |
| if (!consumeIf(Token::colon)) |
| return emitError("expected colon after `strides` keyword"); |
| if (failed(parseStrideList(strides))) |
| return emitError("invalid braces-enclosed stride list"); |
| if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) |
| return emitError("invalid memref stride"); |
| |
| return success(); |
| } |
| |
| /// Parse a memref type. |
| /// |
| /// memref-type ::= ranked-memref-type | unranked-memref-type |
| /// |
| /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type |
| /// (`,` semi-affine-map-composition)? (`,` |
| /// memory-space)? `>` |
| /// |
| /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` |
| /// |
| /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map |
| /// memory-space ::= integer-literal /* | TODO: address-space-id */ |
| /// |
| Type Parser::parseMemRefType() { |
| consumeToken(Token::kw_memref); |
| |
| if (parseToken(Token::less, "expected '<' in memref type")) |
| return nullptr; |
| |
| bool isUnranked; |
| SmallVector<int64_t, 4> dimensions; |
| |
| if (consumeIf(Token::star)) { |
| // This is an unranked memref type. |
| isUnranked = true; |
| if (parseXInDimensionList()) |
| return nullptr; |
| |
| } else { |
| isUnranked = false; |
| if (parseDimensionListRanked(dimensions)) |
| return nullptr; |
| } |
| |
| // Parse the element type. |
| auto typeLoc = getToken().getLoc(); |
| auto elementType = parseType(); |
| if (!elementType) |
| return nullptr; |
| |
| // Check that memref is formed from allowed types. |
| if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() && |
| !elementType.isa<ComplexType>()) |
| return emitError(typeLoc, "invalid memref element type"), nullptr; |
| |
| // Parse semi-affine-map-composition. |
| SmallVector<AffineMap, 2> affineMapComposition; |
| Optional<unsigned> memorySpace; |
| unsigned numDims = dimensions.size(); |
| |
| auto parseElt = [&]() -> ParseResult { |
| // Check for the memory space. |
| if (getToken().is(Token::integer)) { |
| if (memorySpace) |
| return emitError("multiple memory spaces specified in memref type"); |
| memorySpace = getToken().getUnsignedIntegerValue(); |
| if (!memorySpace.hasValue()) |
| return emitError("invalid memory space in memref type"); |
| consumeToken(Token::integer); |
| return success(); |
| } |
| if (isUnranked) |
| return emitError("cannot have affine map for unranked memref type"); |
| if (memorySpace) |
| return emitError("expected memory space to be last in memref type"); |
| |
| AffineMap map; |
| llvm::SMLoc mapLoc = getToken().getLoc(); |
| if (getToken().is(Token::kw_offset)) { |
| int64_t offset; |
| SmallVector<int64_t, 4> strides; |
| if (failed(parseStridedLayout(offset, strides))) |
| return failure(); |
| // Construct strided affine map. |
| map = makeStridedLinearLayoutMap(strides, offset, state.context); |
| } else { |
| // Parse an affine map attribute. |
| auto affineMap = parseAttribute(); |
| if (!affineMap) |
| return failure(); |
| auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>(); |
| if (!affineMapAttr) |
| return emitError("expected affine map in memref type"); |
| map = affineMapAttr.getValue(); |
| } |
| |
| if (map.getNumDims() != numDims) { |
| size_t i = affineMapComposition.size(); |
| return emitError(mapLoc, "memref affine map dimension mismatch between ") |
| << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) |
| << " and affine map" << i + 1 << ": " << numDims |
| << " != " << map.getNumDims(); |
| } |
| numDims = map.getNumResults(); |
| affineMapComposition.push_back(map); |
| return success(); |
| }; |
| |
| // Parse a list of mappings and address space if present. |
| if (!consumeIf(Token::greater)) { |
| // Parse comma separated list of affine maps, followed by memory space. |
| if (parseToken(Token::comma, "expected ',' or '>' in memref type") || |
| parseCommaSeparatedListUntil(Token::greater, parseElt, |
| /*allowEmptyList=*/false)) { |
| return nullptr; |
| } |
| } |
| |
| if (isUnranked) |
| return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); |
| |
| return MemRefType::get(dimensions, elementType, affineMapComposition, |
| memorySpace.getValueOr(0)); |
| } |
| |
| /// Parse any type except the function type. |
| /// |
| /// non-function-type ::= integer-type |
| /// | index-type |
| /// | float-type |
| /// | extended-type |
| /// | vector-type |
| /// | tensor-type |
| /// | memref-type |
| /// | complex-type |
| /// | tuple-type |
| /// | none-type |
| /// |
| /// index-type ::= `index` |
| /// float-type ::= `f16` | `bf16` | `f32` | `f64` |
| /// none-type ::= `none` |
| /// |
| Type Parser::parseNonFunctionType() { |
| switch (getToken().getKind()) { |
| default: |
| return (emitError("expected non-function type"), nullptr); |
| case Token::kw_memref: |
| return parseMemRefType(); |
| case Token::kw_tensor: |
| return parseTensorType(); |
| case Token::kw_complex: |
| return parseComplexType(); |
| case Token::kw_tuple: |
| return parseTupleType(); |
| case Token::kw_vector: |
| return parseVectorType(); |
| // integer-type |
| case Token::inttype: { |
| auto width = getToken().getIntTypeBitwidth(); |
| if (!width.hasValue()) |
| return (emitError("invalid integer width"), nullptr); |
| if (width.getValue() > IntegerType::kMaxWidth) { |
| emitError(getToken().getLoc(), "integer bitwidth is limited to ") |
| << IntegerType::kMaxWidth << " bits"; |
| return nullptr; |
| } |
| |
| IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; |
| if (Optional<bool> signedness = getToken().getIntTypeSignedness()) |
| signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; |
| |
| auto loc = getEncodedSourceLocation(getToken().getLoc()); |
| consumeToken(Token::inttype); |
| return IntegerType::getChecked(width.getValue(), signSemantics, loc); |
| } |
| |
| // float-type |
| case Token::kw_bf16: |
| consumeToken(Token::kw_bf16); |
| return builder.getBF16Type(); |
| case Token::kw_f16: |
| consumeToken(Token::kw_f16); |
| return builder.getF16Type(); |
| case Token::kw_f32: |
| consumeToken(Token::kw_f32); |
| return builder.getF32Type(); |
| case Token::kw_f64: |
| consumeToken(Token::kw_f64); |
| return builder.getF64Type(); |
| |
| // index-type |
| case Token::kw_index: |
| consumeToken(Token::kw_index); |
| return builder.getIndexType(); |
| |
| // none-type |
| case Token::kw_none: |
| consumeToken(Token::kw_none); |
| return builder.getNoneType(); |
| |
| // extended type |
| case Token::exclamation_identifier: |
| return parseExtendedType(); |
| } |
| } |
| |
| /// Parse a tensor type. |
| /// |
| /// tensor-type ::= `tensor` `<` dimension-list type `>` |
| /// dimension-list ::= dimension-list-ranked | `*x` |
| /// |
| Type Parser::parseTensorType() { |
| consumeToken(Token::kw_tensor); |
| |
| if (parseToken(Token::less, "expected '<' in tensor type")) |
| return nullptr; |
| |
| bool isUnranked; |
| SmallVector<int64_t, 4> dimensions; |
| |
| if (consumeIf(Token::star)) { |
| // This is an unranked tensor type. |
| isUnranked = true; |
| |
| if (parseXInDimensionList()) |
| return nullptr; |
| |
| } else { |
| isUnranked = false; |
| if (parseDimensionListRanked(dimensions)) |
| return nullptr; |
| } |
| |
| // Parse the element type. |
| auto elementTypeLoc = getToken().getLoc(); |
| auto elementType = parseType(); |
| if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) |
| return nullptr; |
| if (!TensorType::isValidElementType(elementType)) |
| return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; |
| |
| if (isUnranked) |
| return UnrankedTensorType::get(elementType); |
| return RankedTensorType::get(dimensions, elementType); |
| } |
| |
| /// Parse a tuple type. |
| /// |
| /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` |
| /// |
| Type Parser::parseTupleType() { |
| consumeToken(Token::kw_tuple); |
| |
| // Parse the '<'. |
| if (parseToken(Token::less, "expected '<' in tuple type")) |
| return nullptr; |
| |
| // Check for an empty tuple by directly parsing '>'. |
| if (consumeIf(Token::greater)) |
| return TupleType::get(getContext()); |
| |
| // Parse the element types and the '>'. |
| SmallVector<Type, 4> types; |
| if (parseTypeListNoParens(types) || |
| parseToken(Token::greater, "expected '>' in tuple type")) |
| return nullptr; |
| |
| return TupleType::get(types, getContext()); |
| } |
| |
| /// Parse a vector type. |
| /// |
| /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` |
| /// non-empty-static-dimension-list ::= decimal-literal `x` |
| /// static-dimension-list |
| /// static-dimension-list ::= (decimal-literal `x`)* |
| /// |
| VectorType Parser::parseVectorType() { |
| consumeToken(Token::kw_vector); |
| |
| if (parseToken(Token::less, "expected '<' in vector type")) |
| return nullptr; |
| |
| SmallVector<int64_t, 4> dimensions; |
| if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) |
| return nullptr; |
| if (dimensions.empty()) |
| return (emitError("expected dimension size in vector type"), nullptr); |
| if (any_of(dimensions, [](int64_t i) { return i <= 0; })) |
| return emitError(getToken().getLoc(), |
| "vector types must have positive constant sizes"), |
| nullptr; |
| |
| // Parse the element type. |
| auto typeLoc = getToken().getLoc(); |
| auto elementType = parseType(); |
| if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) |
| return nullptr; |
| if (!VectorType::isValidElementType(elementType)) |
| return emitError(typeLoc, "vector elements must be int or float type"), |
| nullptr; |
| |
| return VectorType::get(dimensions, elementType); |
| } |
| |
| /// Parse a dimension list of a tensor or memref type. This populates the |
| /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and |
| /// errors out on `?` otherwise. |
| /// |
| /// dimension-list-ranked ::= (dimension `x`)* |
| /// dimension ::= `?` | decimal-literal |
| /// |
| /// When `allowDynamic` is not set, this is used to parse: |
| /// |
| /// static-dimension-list ::= (decimal-literal `x`)* |
| ParseResult |
| Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, |
| bool allowDynamic) { |
| while (getToken().isAny(Token::integer, Token::question)) { |
| if (consumeIf(Token::question)) { |
| if (!allowDynamic) |
| return emitError("expected static shape"); |
| dimensions.push_back(-1); |
| } else { |
| // Hexadecimal integer literals (starting with `0x`) are not allowed in |
| // aggregate type declarations. Therefore, `0xf32` should be processed as |
| // a sequence of separate elements `0`, `x`, `f32`. |
| if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { |
| // We can get here only if the token is an integer literal. Hexadecimal |
| // integer literals can only start with `0x` (`1x` wouldn't lex as a |
| // literal, just `1` would, at which point we don't get into this |
| // branch). |
| assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); |
| dimensions.push_back(0); |
| state.lex.resetPointer(getTokenSpelling().data() + 1); |
| consumeToken(); |
| } else { |
| // Make sure this integer value is in bound and valid. |
| auto dimension = getToken().getUnsignedIntegerValue(); |
| if (!dimension.hasValue()) |
| return emitError("invalid dimension"); |
| dimensions.push_back((int64_t)dimension.getValue()); |
| consumeToken(Token::integer); |
| } |
| } |
| |
| // Make sure we have an 'x' or something like 'xbf32'. |
| if (parseXInDimensionList()) |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| /// Parse an 'x' token in a dimension list, handling the case where the x is |
| /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next |
| /// token. |
| ParseResult Parser::parseXInDimensionList() { |
| if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') |
| return emitError("expected 'x' in dimension list"); |
| |
| // If we had a prefix of 'x', lex the next token immediately after the 'x'. |
| if (getTokenSpelling().size() != 1) |
| state.lex.resetPointer(getTokenSpelling().data() + 1); |
| |
| // Consume the 'x'. |
| consumeToken(Token::bare_identifier); |
| |
| return success(); |
| } |
| |
| // Parse a comma-separated list of dimensions, possibly empty: |
| // stride-list ::= `[` (dimension (`,` dimension)*)? `]` |
| ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { |
| if (!consumeIf(Token::l_square)) |
| return failure(); |
| // Empty list early exit. |
| if (consumeIf(Token::r_square)) |
| return success(); |
| while (true) { |
| if (consumeIf(Token::question)) { |
| dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); |
| } else { |
| // This must be an integer value. |
| int64_t val; |
| if (getToken().getSpelling().getAsInteger(10, val)) |
| return emitError("invalid integer value: ") << getToken().getSpelling(); |
| // Make sure it is not the one value for `?`. |
| if (ShapedType::isDynamic(val)) |
| return emitError("invalid integer value: ") |
| << getToken().getSpelling() |
| << ", use `?` to specify a dynamic dimension"; |
| dimensions.push_back(val); |
| consumeToken(Token::integer); |
| } |
| if (!consumeIf(Token::comma)) |
| break; |
| } |
| if (!consumeIf(Token::r_square)) |
| return failure(); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Attribute parsing. |
| //===----------------------------------------------------------------------===// |
| |
| /// Return the symbol reference referred to by the given token, that is known to |
| /// be an @-identifier. |
| static std::string extractSymbolReference(Token tok) { |
| assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); |
| StringRef nameStr = tok.getSpelling().drop_front(); |
| |
| // Check to see if the reference is a string literal, or a bare identifier. |
| if (nameStr.front() == '"') |
| return tok.getStringValue(); |
| return std::string(nameStr); |
| } |
| |
| /// Parse an arbitrary attribute. |
| /// |
| /// attribute-value ::= `unit` |
| /// | bool-literal |
| /// | integer-literal (`:` (index-type | integer-type))? |
| /// | float-literal (`:` float-type)? |
| /// | string-literal (`:` type)? |
| /// | type |
| /// | `[` (attribute-value (`,` attribute-value)*)? `]` |
| /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` |
| /// | symbol-ref-id (`::` symbol-ref-id)* |
| /// | `dense` `<` attribute-value `>` `:` |
| /// (tensor-type | vector-type) |
| /// | `sparse` `<` attribute-value `,` attribute-value `>` |
| /// `:` (tensor-type | vector-type) |
| /// | `opaque` `<` dialect-namespace `,` hex-string-literal |
| /// `>` `:` (tensor-type | vector-type) |
| /// | extended-attribute |
| /// |
| Attribute Parser::parseAttribute(Type type) { |
| switch (getToken().getKind()) { |
| // Parse an AffineMap or IntegerSet attribute. |
| case Token::kw_affine_map: { |
| consumeToken(Token::kw_affine_map); |
| |
| AffineMap map; |
| if (parseToken(Token::less, "expected '<' in affine map") || |
| parseAffineMapReference(map) || |
| parseToken(Token::greater, "expected '>' in affine map")) |
| return Attribute(); |
| return AffineMapAttr::get(map); |
| } |
| case Token::kw_affine_set: { |
| consumeToken(Token::kw_affine_set); |
| |
| IntegerSet set; |
| if (parseToken(Token::less, "expected '<' in integer set") || |
| parseIntegerSetReference(set) || |
| parseToken(Token::greater, "expected '>' in integer set")) |
| return Attribute(); |
| return IntegerSetAttr::get(set); |
| } |
| |
| // Parse an array attribute. |
| case Token::l_square: { |
| consumeToken(Token::l_square); |
| |
| SmallVector<Attribute, 4> elements; |
| auto parseElt = [&]() -> ParseResult { |
| elements.push_back(parseAttribute()); |
| return elements.back() ? success() : failure(); |
| }; |
| |
| if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) |
| return nullptr; |
| return builder.getArrayAttr(elements); |
| } |
| |
| // Parse a boolean attribute. |
| case Token::kw_false: |
| consumeToken(Token::kw_false); |
| return builder.getBoolAttr(false); |
| case Token::kw_true: |
| consumeToken(Token::kw_true); |
| return builder.getBoolAttr(true); |
| |
| // Parse a dense elements attribute. |
| case Token::kw_dense: |
| return parseDenseElementsAttr(type); |
| |
| // Parse a dictionary attribute. |
| case Token::l_brace: { |
| SmallVector<NamedAttribute, 4> elements; |
| if (parseAttributeDict(elements)) |
| return nullptr; |
| return builder.getDictionaryAttr(elements); |
| } |
| |
| // Parse an extended attribute, i.e. alias or dialect attribute. |
| case Token::hash_identifier: |
| return parseExtendedAttr(type); |
| |
| // Parse floating point and integer attributes. |
| case Token::floatliteral: |
| return parseFloatAttr(type, /*isNegative=*/false); |
| case Token::integer: |
| return parseDecOrHexAttr(type, /*isNegative=*/false); |
| case Token::minus: { |
| consumeToken(Token::minus); |
| if (getToken().is(Token::integer)) |
| return parseDecOrHexAttr(type, /*isNegative=*/true); |
| if (getToken().is(Token::floatliteral)) |
| return parseFloatAttr(type, /*isNegative=*/true); |
| |
| return (emitError("expected constant integer or floating point value"), |
| nullptr); |
| } |
| |
| // Parse a location attribute. |
| case Token::kw_loc: { |
| LocationAttr attr; |
| return failed(parseLocation(attr)) ? Attribute() : attr; |
| } |
| |
| // Parse an opaque elements attribute. |
| case Token::kw_opaque: |
| return parseOpaqueElementsAttr(type); |
| |
| // Parse a sparse elements attribute. |
| case Token::kw_sparse: |
| return parseSparseElementsAttr(type); |
| |
| // Parse a string attribute. |
| case Token::string: { |
| auto val = getToken().getStringValue(); |
| consumeToken(Token::string); |
| // Parse the optional trailing colon type if one wasn't explicitly provided. |
| if (!type && consumeIf(Token::colon) && !(type = parseType())) |
| return Attribute(); |
| |
| return type ? StringAttr::get(val, type) |
| : StringAttr::get(val, getContext()); |
| } |
| |
| // Parse a symbol reference attribute. |
| case Token::at_identifier: { |
| std::string nameStr = extractSymbolReference(getToken()); |
| consumeToken(Token::at_identifier); |
| |
| // Parse any nested references. |
| std::vector<FlatSymbolRefAttr> nestedRefs; |
| while (getToken().is(Token::colon)) { |
| // Check for the '::' prefix. |
| const char *curPointer = getToken().getLoc().getPointer(); |
| consumeToken(Token::colon); |
| if (!consumeIf(Token::colon)) { |
| state.lex.resetPointer(curPointer); |
| consumeToken(); |
| break; |
| } |
| // Parse the reference itself. |
| auto curLoc = getToken().getLoc(); |
| if (getToken().isNot(Token::at_identifier)) { |
| emitError(curLoc, "expected nested symbol reference identifier"); |
| return Attribute(); |
| } |
| |
| std::string nameStr = extractSymbolReference(getToken()); |
| consumeToken(Token::at_identifier); |
| nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); |
| } |
| |
| return builder.getSymbolRefAttr(nameStr, nestedRefs); |
| } |
| |
| // Parse a 'unit' attribute. |
| case Token::kw_unit: |
| consumeToken(Token::kw_unit); |
| return builder.getUnitAttr(); |
| |
| default: |
| // Parse a type attribute. |
| if (Type type = parseType()) |
| return TypeAttr::get(type); |
| return nullptr; |
| } |
| } |
| |
| /// Attribute dictionary. |
| /// |
| /// attribute-dict ::= `{` `}` |
| /// | `{` attribute-entry (`,` attribute-entry)* `}` |
| /// attribute-entry ::= bare-id `=` attribute-value |
| /// |
| ParseResult |
| Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { |
| if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) |
| return failure(); |
| |
| auto parseElt = [&]() -> ParseResult { |
| // We allow keywords as attribute names. |
| if (getToken().isNot(Token::bare_identifier, Token::inttype) && |
| !getToken().isKeyword()) |
| return emitError("expected attribute name"); |
| Identifier nameId = builder.getIdentifier(getTokenSpelling()); |
| consumeToken(); |
| |
| // Try to parse the '=' for the attribute value. |
| if (!consumeIf(Token::equal)) { |
| // If there is no '=', we treat this as a unit attribute. |
| attributes.push_back({nameId, builder.getUnitAttr()}); |
| return success(); |
| } |
| |
| auto attr = parseAttribute(); |
| if (!attr) |
| return failure(); |
| |
| attributes.push_back({nameId, attr}); |
| return success(); |
| }; |
| |
| if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) |
| return failure(); |
| |
| return success(); |
| } |
| |
| /// Parse an extended attribute. |
| /// |
| /// extended-attribute ::= (dialect-attribute | attribute-alias) |
| /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` |
| /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? |
| /// attribute-alias ::= `#` alias-name |
| /// |
| Attribute Parser::parseExtendedAttr(Type type) { |
| Attribute attr = parseExtendedSymbol<Attribute>( |
| *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, |
| [&](StringRef dialectName, StringRef symbolData, |
| llvm::SMLoc loc) -> Attribute { |
| // Parse an optional trailing colon type. |
| Type attrType = type; |
| if (consumeIf(Token::colon) && !(attrType = parseType())) |
| return Attribute(); |
| |
| // If we found a registered dialect, then ask it to parse the attribute. |
| if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { |
| return parseSymbol<Attribute>( |
| symbolData, state.context, state.symbols, [&](Parser &parser) { |
| CustomDialectAsmParser customParser(symbolData, parser); |
| return dialect->parseAttribute(customParser, attrType); |
| }); |
| } |
| |
| // Otherwise, form a new opaque attribute. |
| return OpaqueAttr::getChecked( |
| Identifier::get(dialectName, state.context), symbolData, |
| attrType ? attrType : NoneType::get(state.context), |
| getEncodedSourceLocation(loc)); |
| }); |
| |
| // Ensure that the attribute has the same type as requested. |
| if (attr && type && attr.getType() != type) { |
| emitError("attribute type different than expected: expected ") |
| << type << ", but got " << attr.getType(); |
| return nullptr; |
| } |
| return attr; |
| } |
| |
| /// Parse a float attribute. |
| Attribute Parser::parseFloatAttr(Type type, bool isNegative) { |
| auto val = getToken().getFloatingPointValue(); |
| if (!val.hasValue()) |
| return (emitError("floating point value too large for attribute"), nullptr); |
| consumeToken(Token::floatliteral); |
| if (!type) { |
| // Default to F64 when no type is specified. |
| if (!consumeIf(Token::colon)) |
| type = builder.getF64Type(); |
| else if (!(type = parseType())) |
| return nullptr; |
| } |
| if (!type.isa<FloatType>()) |
| return (emitError("floating point value not valid for specified type"), |
| nullptr); |
| return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); |
| } |
| |
| /// Construct a float attribute bitwise equivalent to the integer literal. |
| static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type, |
| uint64_t value) { |
| // FIXME: bfloat is currently stored as a double internally because it doesn't |
| // have valid APFloat semantics. |
| if (type.isF64() || type.isBF16()) |
| return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); |
| |
| APInt apInt(type.getWidth(), value); |
| if (apInt != value) { |
| p->emitError("hexadecimal float constant out of range for type"); |
| return llvm::None; |
| } |
| return APFloat(type.getFloatSemantics(), apInt); |
| } |
| |
| /// Parse a decimal or a hexadecimal literal, which can be either an integer |
| /// or a float attribute. |
| Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { |
| auto val = getToken().getUInt64IntegerValue(); |
| if (!val.hasValue()) |
| return (emitError("integer constant out of range for attribute"), nullptr); |
| |
| // Remember if the literal is hexadecimal. |
| StringRef spelling = getToken().getSpelling(); |
| auto loc = state.curToken.getLoc(); |
| bool isHex = spelling.size() > 1 && spelling[1] == 'x'; |
| |
| consumeToken(Token::integer); |
| if (!type) { |
| // Default to i64 if not type is specified. |
| if (!consumeIf(Token::colon)) |
| type = builder.getIntegerType(64); |
| else if (!(type = parseType())) |
| return nullptr; |
| } |
| |
| if (auto floatType = type.dyn_cast<FloatType>()) { |
| if (isNegative) |
| return emitError( |
| loc, |
| "hexadecimal float literal should not have a leading minus"), |
| nullptr; |
| if (!isHex) { |
| emitError(loc, "unexpected decimal integer literal for a float attribute") |
| .attachNote() |
| << "add a trailing dot to make the literal a float"; |
| return nullptr; |
| } |
| |
| // Construct a float attribute bitwise equivalent to the integer literal. |
| Optional<APFloat> apVal = |
| buildHexadecimalFloatLiteral(this, floatType, *val); |
| return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); |
| } |
| |
| if (!type.isa<IntegerType>() && !type.isa<IndexType>()) |
| return emitError(loc, "integer literal not valid for specified type"), |
| nullptr; |
| |
| if (isNegative && type.isUnsignedInteger()) { |
| emitError(loc, |
| "negative integer literal not valid for unsigned integer type"); |
| return nullptr; |
| } |
| |
| // Parse the integer literal. |
| int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); |
| APInt apInt(width, *val, isNegative); |
| if (apInt != *val) |
| return emitError(loc, "integer constant out of range for attribute"), |
| nullptr; |
| |
| // Otherwise construct an integer attribute. |
| if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) |
| return emitError(loc, "integer constant out of range for attribute"), |
| nullptr; |
| |
| return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); |
| } |
| |
| /// Parse elements values stored within a hex etring. On success, the values are |
| /// stored into 'result'. |
| static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, |
| std::string &result) { |
| std::string val = tok.getStringValue(); |
| if (val.size() < 2 || val[0] != '0' || val[1] != 'x') |
| return parser.emitError(tok.getLoc(), |
| "elements hex string should start with '0x'"); |
| |
| StringRef hexValues = StringRef(val).drop_front(2); |
| if (!llvm::all_of(hexValues, llvm::isHexDigit)) |
| return parser.emitError(tok.getLoc(), |
| "elements hex string only contains hex digits"); |
| |
| result = llvm::fromHex(hexValues); |
| return success(); |
| } |
| |
| /// Parse an opaque elements attribute. |
| Attribute Parser::parseOpaqueElementsAttr(Type attrType) { |
| consumeToken(Token::kw_opaque); |
| if (parseToken(Token::less, "expected '<' after 'opaque'")) |
| return nullptr; |
| |
| if (getToken().isNot(Token::string)) |
| return (emitError("expected dialect namespace"), nullptr); |
| |
| auto name = getToken().getStringValue(); |
| auto *dialect = builder.getContext()->getRegisteredDialect(name); |
| // TODO(shpeisman): Allow for having an unknown dialect on an opaque |
| // attribute. Otherwise, it can't be roundtripped without having the dialect |
| // registered. |
| if (!dialect) |
| return (emitError("no registered dialect with namespace '" + name + "'"), |
| nullptr); |
| consumeToken(Token::string); |
| |
| if (parseToken(Token::comma, "expected ','")) |
| return nullptr; |
| |
| Token hexTok = getToken(); |
| if (parseToken(Token::string, "elements hex string should start with '0x'") || |
| parseToken(Token::greater, "expected '>'")) |
| return nullptr; |
| auto type = parseElementsLiteralType(attrType); |
| if (!type) |
| return nullptr; |
| |
| std::string data; |
| if (parseElementAttrHexValues(*this, hexTok, data)) |
| return nullptr; |
| return OpaqueElementsAttr::get(dialect, type, data); |
| } |
| |
| namespace { |
| class TensorLiteralParser { |
| public: |
| TensorLiteralParser(Parser &p) : p(p) {} |
| |
| /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
| /// may also parse a tensor literal that is store as a hex string. |
| ParseResult parse(bool allowHex); |
| |
| /// Build a dense attribute instance with the parsed elements and the given |
| /// shaped type. |
| DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); |
| |
| ArrayRef<int64_t> getShape() const { return shape; } |
| |
| private: |
| enum class ElementKind { Boolean, Integer, Float }; |
| |
| /// Return a string to represent the given element kind. |
| const char *getElementKindStr(ElementKind kind) { |
| switch (kind) { |
| case ElementKind::Boolean: |
| return "'boolean'"; |
| case ElementKind::Integer: |
| return "'integer'"; |
| case ElementKind::Float: |
| return "'float'"; |
| } |
| llvm_unreachable("unknown element kind"); |
| } |
| |
| /// Build a Dense Integer attribute for the given type. |
| DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, |
| IntegerType eltTy); |
| |
| /// Build a Dense Float attribute for the given type. |
| DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, |
| FloatType eltTy); |
| |
| /// Build a Dense attribute with hex data for the given type. |
| DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); |
| |
| /// Parse a single element, returning failure if it isn't a valid element |
| /// literal. For example: |
| /// parseElement(1) -> Success, 1 |
| /// parseElement([1]) -> Failure |
| ParseResult parseElement(); |
| |
| /// Parse a list of either lists or elements, returning the dimensions of the |
| /// parsed sub-tensors in dims. For example: |
| /// parseList([1, 2, 3]) -> Success, [3] |
| /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
| /// parseList([[1, 2], 3]) -> Failure |
| /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
| ParseResult parseList(SmallVectorImpl<int64_t> &dims); |
| |
| /// Parse a literal that was printed as a hex string. |
| ParseResult parseHexElements(); |
| |
| Parser &p; |
| |
| /// The shape inferred from the parsed elements. |
| SmallVector<int64_t, 4> shape; |
| |
| /// Storage used when parsing elements, this is a pair of <is_negated, token>. |
| std::vector<std::pair<bool, Token>> storage; |
| |
| /// A flag that indicates the type of elements that have been parsed. |
| Optional<ElementKind> knownEltKind; |
| |
| /// Storage used when parsing elements that were stored as hex values. |
| Optional<Token> hexStorage; |
| }; |
| } // namespace |
| |
| /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser |
| /// may also parse a tensor literal that is store as a hex string. |
| ParseResult TensorLiteralParser::parse(bool allowHex) { |
| // If hex is allowed, check for a string literal. |
| if (allowHex && p.getToken().is(Token::string)) { |
| hexStorage = p.getToken(); |
| p.consumeToken(Token::string); |
| return success(); |
| } |
| // Otherwise, parse a list or an individual element. |
| if (p.getToken().is(Token::l_square)) |
| return parseList(shape); |
| return parseElement(); |
| } |
| |
| /// Build a dense attribute instance with the parsed elements and the given |
| /// shaped type. |
| DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, |
| ShapedType type) { |
| // Check to see if we parsed the literal from a hex string. |
| if (hexStorage.hasValue()) |
| return getHexAttr(loc, type); |
| |
| // Check that the parsed storage size has the same number of elements to the |
| // type, or is a known splat. |
| if (!shape.empty() && getShape() != type.getShape()) { |
| p.emitError(loc) << "inferred shape of elements literal ([" << getShape() |
| << "]) does not match type ([" << type.getShape() << "])"; |
| return nullptr; |
| } |
| |
| // If the type is an integer, build a set of APInt values from the storage |
| // with the correct bitwidth. |
| if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) |
| return getIntAttr(loc, type, intTy); |
| |
| // Otherwise, this must be a floating point type. |
| auto floatTy = type.getElementType().dyn_cast<FloatType>(); |
| if (!floatTy) { |
| p.emitError(loc) << "expected floating-point or integer element type, got " |
| << type.getElementType(); |
| return nullptr; |
| } |
| return getFloatAttr(loc, type, floatTy); |
| } |
| |
| /// Build a Dense Integer attribute for the given type. |
| DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, |
| ShapedType type, |
| IntegerType eltTy) { |
| std::vector<APInt> intElements; |
| intElements.reserve(storage.size()); |
| auto isUintType = type.getElementType().isUnsignedInteger(); |
| for (const auto &signAndToken : storage) { |
| bool isNegative = signAndToken.first; |
| const Token &token = signAndToken.second; |
| auto tokenLoc = token.getLoc(); |
| |
| if (isNegative && isUintType) { |
| p.emitError(tokenLoc) |
| << "expected unsigned integer elements, but parsed negative value"; |
| return nullptr; |
| } |
| |
| // Check to see if floating point values were parsed. |
| if (token.is(Token::floatliteral)) { |
| p.emitError(tokenLoc) |
| << "expected integer elements, but parsed floating-point"; |
| return nullptr; |
| } |
| |
| assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && |
| "unexpected token type"); |
| if (token.isAny(Token::kw_true, Token::kw_false)) { |
| if (!eltTy.isInteger(1)) |
| p.emitError(tokenLoc) |
| << "expected i1 type for 'true' or 'false' values"; |
| APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), |
| /*isSigned=*/false); |
| intElements.push_back(apInt); |
| continue; |
| } |
| |
| // Create APInt values for each element with the correct bitwidth. |
| auto val = token.getUInt64IntegerValue(); |
| if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 |
| : (int64_t)val.getValue() < 0)) { |
| p.emitError(tokenLoc, "integer constant out of range for attribute"); |
| return nullptr; |
| } |
| APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); |
| if (apInt != val.getValue()) |
| return (p.emitError(tokenLoc, "integer constant out of range for type"), |
| nullptr); |
| intElements.push_back(isNegative ? -apInt : apInt); |
| } |
| |
| return DenseElementsAttr::get(type, intElements); |
| } |
| |
| /// Build a Dense Float attribute for the given type. |
| DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, |
| ShapedType type, |
| FloatType eltTy) { |
| std::vector<APFloat> floatValues; |
| floatValues.reserve(storage.size()); |
| for (const auto &signAndToken : storage) { |
| bool isNegative = signAndToken.first; |
| const Token &token = signAndToken.second; |
| |
| // Handle hexadecimal float literals. |
| if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { |
| if (isNegative) { |
| p.emitError(token.getLoc()) |
| << "hexadecimal float literal should not have a leading minus"; |
| return nullptr; |
| } |
| auto val = token.getUInt64IntegerValue(); |
| if (!val.hasValue()) { |
| p.emitError("hexadecimal float constant out of range for attribute"); |
| return nullptr; |
| } |
| Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); |
| if (!apVal) |
| return nullptr; |
| floatValues.push_back(*apVal); |
| continue; |
| } |
| |
| // Check to see if any decimal integers or booleans were parsed. |
| if (!token.is(Token::floatliteral)) { |
| p.emitError() << "expected floating-point elements, but parsed integer"; |
| return nullptr; |
| } |
| |
| // Build the float values from tokens. |
| auto val = token.getFloatingPointValue(); |
| if (!val.hasValue()) { |
| p.emitError("floating point value too large for attribute"); |
| return nullptr; |
| } |
| // Treat BF16 as double because it is not supported in LLVM's APFloat. |
| APFloat apVal(isNegative ? -*val : *val); |
| if (!eltTy.isBF16() && !eltTy.isF64()) { |
| bool unused; |
| apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, |
| &unused); |
| } |
| floatValues.push_back(apVal); |
| } |
| |
| return DenseElementsAttr::get(type, floatValues); |
| } |
| |
| /// Build a Dense attribute with hex data for the given type. |
| DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, |
| ShapedType type) { |
| Type elementType = type.getElementType(); |
| if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) { |
| p.emitError(loc) << "expected floating-point or integer element type, got " |
| << elementType; |
| return nullptr; |
| } |
| |
| std::string data; |
| if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) |
| return nullptr; |
| |
| // Check that the size of the hex data correpsonds to the size of the type, or |
| // a splat of the type. |
| if (static_cast<int64_t>(data.size() * CHAR_BIT) != |
| (type.getNumElements() * elementType.getIntOrFloatBitWidth())) { |
| p.emitError(loc) << "elements hex data size is invalid for provided type: " |
| << type; |
| return nullptr; |
| } |
| |
| return DenseElementsAttr::getFromRawBuffer( |
| type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false); |
| } |
| |
| ParseResult TensorLiteralParser::parseElement() { |
| switch (p.getToken().getKind()) { |
| // Parse a boolean element. |
| case Token::kw_true: |
| case Token::kw_false: |
| case Token::floatliteral: |
| case Token::integer: |
| storage.emplace_back(/*isNegative=*/false, p.getToken()); |
| p.consumeToken(); |
| break; |
| |
| // Parse a signed integer or a negative floating-point element. |
| case Token::minus: |
| p.consumeToken(Token::minus); |
| if (!p.getToken().isAny(Token::floatliteral, Token::integer)) |
| return p.emitError("expected integer or floating point literal"); |
| storage.emplace_back(/*isNegative=*/true, p.getToken()); |
| p.consumeToken(); |
| break; |
| |
| default: |
| return p.emitError("expected element literal of primitive type"); |
| } |
| |
| return success(); |
| } |
| |
| /// Parse a list of either lists or elements, returning the dimensions of the |
| /// parsed sub-tensors in dims. For example: |
| /// parseList([1, 2, 3]) -> Success, [3] |
| /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] |
| /// parseList([[1, 2], 3]) -> Failure |
| /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure |
| ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { |
| p.consumeToken(Token::l_square); |
| |
| auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, |
| const SmallVectorImpl<int64_t> &newDims) -> ParseResult { |
| if (prevDims == newDims) |
| return success(); |
| return p.emitError("tensor literal is invalid; ranks are not consistent " |
| "between elements"); |
| }; |
| |
| bool first = true; |
| SmallVector<int64_t, 4> newDims; |
| unsigned size = 0; |
| auto parseCommaSeparatedList = [&]() -> ParseResult { |
| SmallVector<int64_t, 4> thisDims; |
| if (p.getToken().getKind() == Token::l_square) { |
| if (parseList(thisDims)) |
| return failure(); |
| } else if (parseElement()) { |
| return failure(); |
| } |
| ++size; |
| if (!first) |
| return checkDims(newDims, thisDims); |
| newDims = thisDims; |
| first = false; |
| return success(); |
| }; |
| if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) |
| return failure(); |
| |
| // Return the sublists' dimensions with 'size' prepended. |
| dims.clear(); |
| dims.push_back(size); |
| dims.append(newDims.begin(), newDims.end()); |
| return success(); |
| } |
| |
| /// Parse a dense elements attribute. |
| Attribute Parser::parseDenseElementsAttr(Type attrType) { |
| consumeToken(Token::kw_dense); |
| if (parseToken(Token::less, "expected '<' after 'dense'")) |
| return nullptr; |
| |
| // Parse the literal data. |
| TensorLiteralParser literalParser(*this); |
| if (literalParser.parse(/*allowHex=*/true)) |
| return nullptr; |
| |
| if (parseToken(Token::greater, "expected '>'")) |
| return nullptr; |
| |
| auto typeLoc = getToken().getLoc(); |
| auto type = parseElementsLiteralType(attrType); |
| if (!type) |
| return nullptr; |
| return literalParser.getAttr(typeLoc, type); |
| } |
| |
| /// Shaped type for elements attribute. |
| /// |
| /// elements-literal-type ::= vector-type | ranked-tensor-type |
| /// |
| /// This method also checks the type has static shape. |
| ShapedType Parser::parseElementsLiteralType(Type type) { |
| // If the user didn't provide a type, parse the colon type for the literal. |
| if (!type) { |
| if (parseToken(Token::colon, "expected ':'")) |
| return nullptr; |
| if (!(type = parseType())) |
| return nullptr; |
| } |
| |
| if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { |
| emitError("elements literal must be a ranked tensor or vector type"); |
| return nullptr; |
| } |
| |
| auto sType = type.cast<ShapedType>(); |
| if (!sType.hasStaticShape()) |
| return (emitError("elements literal type must have static shape"), nullptr); |
| |
| return sType; |
| } |
| |
| /// Parse a sparse elements attribute. |
| Attribute Parser::parseSparseElementsAttr(Type attrType) { |
| consumeToken(Token::kw_sparse); |
| if (parseToken(Token::less, "Expected '<' after 'sparse'")) |
| return nullptr; |
| |
| /// Parse the indices. We don't allow hex values here as we may need to use |
| /// the inferred shape. |
| auto indicesLoc = getToken().getLoc(); |
| TensorLiteralParser indiceParser(*this); |
| if (indiceParser.parse(/*allowHex=*/false)) |
| return nullptr; |
| |
| if (parseToken(Token::comma, "expected ','")) |
| return nullptr; |
| |
| /// Parse the values. |
| auto valuesLoc = getToken().getLoc(); |
| TensorLiteralParser valuesParser(*this); |
| if (valuesParser.parse(/*allowHex=*/true)) |
| return nullptr; |
| |
| if (parseToken(Token::greater, "expected '>'")) |
| return nullptr; |
| |
| auto type = parseElementsLiteralType(attrType); |
| if (!type) |
| return nullptr; |
| |
| // If the indices are a splat, i.e. the literal parser parsed an element and |
| // not a list, we set the shape explicitly. The indices are represented by a |
| // 2-dimensional shape where the second dimension is the rank of the type. |
| // Given that the parsed indices is a splat, we know that we only have one |
| // indice and thus one for the first dimension. |
| auto indiceEltType = builder.getIntegerType(64); |
| ShapedType indicesType; |
| if (indiceParser.getShape().empty()) { |
| indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); |
| } else { |
| // Otherwise, set the shape to the one parsed by the literal parser. |
| indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); |
| } |
| auto indices = indiceParser.getAttr(indicesLoc, indicesType); |
| |
| // If the values are a splat, set the shape explicitly based on the number of |
| // indices. The number of indices is encoded in the first dimension of the |
| // indice shape type. |
| auto valuesEltType = type.getElementType(); |
| ShapedType valuesType = |
| valuesParser.getShape().empty() |
| ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) |
| : RankedTensorType::get(valuesParser.getShape(), valuesEltType); |
| auto values = valuesParser.getAttr(valuesLoc, valuesType); |
| |
| /// Sanity check. |
| if (valuesType.getRank() != 1) |
| return (emitError("expected 1-d tensor for values"), nullptr); |
| |
| auto sameShape = (indicesType.getRank() == 1) || |
| (type.getRank() == indicesType.getDimSize(1)); |
| auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); |
| if (!sameShape || !sameElementNum) { |
| emitError() << "expected shape ([" << type.getShape() |
| << "]); inferred shape of indices literal ([" |
| << indicesType.getShape() |
| << "]); inferred shape of values literal ([" |
| << valuesType.getShape() << "])"; |
| return nullptr; |
| } |
| |
| // Build the sparse elements attribute by the indices and values. |
| return SparseElementsAttr::get(type, indices, values); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Location parsing. |
| //===----------------------------------------------------------------------===// |
| |
| /// Parse a location. |
| /// |
| /// location ::= `loc` inline-location |
| /// inline-location ::= '(' location-inst ')' |
| /// |
| ParseResult Parser::parseLocation(LocationAttr &loc) { |
| // Check for 'loc' identifier. |
| if (parseToken(Token::kw_loc, "expected 'loc' keyword")) |
| return emitError(); |
| |
| // Parse the inline-location. |
| if (parseToken(Token::l_paren, "expected '(' in inline location") || |
| parseLocationInstance(loc) || |
| parseToken(Token::r_paren, "expected ')' in inline location")) |
| return failure(); |
| return success(); |
| } |
| |
| /// Specific location instances. |
| /// |
| /// location-inst ::= filelinecol-location | |
| /// name-location | |
| /// callsite-location | |
| /// fused-location | |
| /// unknown-location |
| /// filelinecol-location ::= string-literal ':' integer-literal |
| /// ':' integer-literal |
| /// name-location ::= string-literal |
| /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' |
| /// fused-location ::= fused ('<' attribute-value '>')? |
| /// '[' location-inst (location-inst ',')* ']' |
| /// unknown-location ::= 'unknown' |
| /// |
| ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { |
| consumeToken(Token::bare_identifier); |
| |
| // Parse the '('. |
| if (parseToken(Token::l_paren, "expected '(' in callsite location")) |
| return failure(); |
| |
| // Parse the callee location. |
| LocationAttr calleeLoc; |
| if (parseLocationInstance(calleeLoc)) |
| return failure(); |
| |
| // Parse the 'at'. |
| if (getToken().isNot(Token::bare_identifier) || |
| getToken().getSpelling() != "at") |
| return emitError("expected 'at' in callsite location"); |
| consumeToken(Token::bare_identifier); |
| |
| // Parse the caller location. |
| LocationAttr callerLoc; |
| if (parseLocationInstance(callerLoc)) |
| return failure(); |
| |
| // Parse the ')'. |
| if (parseToken(Token::r_paren, "expected ')' in callsite location")) |
| return failure(); |
| |
| // Return the callsite location. |
| loc = CallSiteLoc::get(calleeLoc, callerLoc); |
| return success(); |
| } |
| |
| ParseResult Parser::parseFusedLocation(LocationAttr &loc) { |
| consumeToken(Token::bare_identifier); |
| |
| // Try to parse the optional metadata. |
| Attribute metadata; |
| if (consumeIf(Token::less)) { |
| metadata = parseAttribute(); |
| if (!metadata) |
| return emitError("expected valid attribute metadata"); |
| // Parse the '>' token. |
| if (parseToken(Token::greater, |
| "expected '>' after fused location metadata")) |
| return failure(); |
| } |
| |
| SmallVector<Location, 4> locations; |
| auto parseElt = [&] { |
| LocationAttr newLoc; |
| if (parseLocationInstance(newLoc)) |
| return failure(); |
| locations.push_back(newLoc); |
| return success(); |
| }; |
| |
| if (parseToken(Token::l_square, "expected '[' in fused location") || |
| parseCommaSeparatedList(parseElt) || |
| parseToken(Token::r_square, "expected ']' in fused location")) |
| return failure(); |
| |
| // Return the fused location. |
| loc = FusedLoc::get(locations, metadata, getContext()); |
| return success(); |
| } |
| |
| ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { |
| auto *ctx = getContext(); |
| auto str = getToken().getStringValue(); |
| consumeToken(Token::string); |
| |
| // If the next token is ':' this is a filelinecol location. |
| if (consumeIf(Token::colon)) { |
| // Parse the line number. |
| if (getToken().isNot(Token::integer)) |
| return emitError("expected integer line number in FileLineColLoc"); |
| auto line = getToken().getUnsignedIntegerValue(); |
| if (!line.hasValue()) |
| return emitError("expected integer line number in FileLineColLoc"); |
| consumeToken(Token::integer); |
| |
| // Parse the ':'. |
| if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) |
| return failure(); |
| |
| // Parse the column number. |
| if (getToken().isNot(Token::integer)) |
| return emitError("expected integer column number in FileLineColLoc"); |
| auto column = getToken().getUnsignedIntegerValue(); |
| if (!column.hasValue()) |
| return emitError("expected integer column number in FileLineColLoc"); |
| consumeToken(Token::integer); |
| |
| loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); |
| return success(); |
| } |
| |
| // Otherwise, this is a NameLoc. |
| |
| // Check for a child location. |
| if (consumeIf(Token::l_paren)) { |
| auto childSourceLoc = getToken().getLoc(); |
| |
| // Parse the child location. |
| LocationAttr childLoc; |
| if (parseLocationInstance(childLoc)) |
| return failure(); |
| |
| // The child must not be another NameLoc. |
| if (childLoc.isa<NameLoc>()) |
| return emitError(childSourceLoc, |
| "child of NameLoc cannot be another NameLoc"); |
| loc = NameLoc::get(Identifier::get(str, ctx), childLoc); |
| |
| // Parse the closing ')'. |
| if (parseToken(Token::r_paren, |
| "expected ')' after child location of NameLoc")) |
| return failure(); |
| } else { |
| loc = NameLoc::get(Identifier::get(str, ctx), ctx); |
| } |
| |
| return success(); |
| } |
| |
| ParseResult Parser::parseLocationInstance(LocationAttr &loc) { |
| // Handle either name or filelinecol locations. |
| if (getToken().is(Token::string)) |
| return parseNameOrFileLineColLocation(loc); |
| |
| // Bare tokens required for other cases. |
| if (!getToken().is(Token::bare_identifier)) |
| return emitError("expected location instance"); |
| |
| // Check for the 'callsite' signifying a callsite location. |
| if (getToken().getSpelling() == "callsite") |
| return parseCallSiteLocation(loc); |
| |
| // If the token is 'fused', then this is a fused location. |
| if (getToken().getSpelling() == "fused") |
| return parseFusedLocation(loc); |
| |
| // Check for a 'unknown' for an unknown location. |
| if (getToken().getSpelling() == "unknown") { |
| consumeToken(Token::bare_identifier); |
| loc = UnknownLoc::get(getContext()); |
| return success(); |
| } |
| |
| return emitError("expected location instance"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Affine parsing. |
| //===----------------------------------------------------------------------===// |
| |
/// Lower precedence affine binary ops; all of them share one precedence
/// level. LNoOp has the value zero, so it is false in the boolean sense while
/// every real op is true.
enum AffineLowPrecOp {
  /// Null value.
  LNoOp = 0,
  /// Addition.
  Add = 1,
  /// Subtraction.
  Sub = 2
};
| |
/// Higher precedence affine binary ops; all of them share one precedence
/// level. HNoOp has the value zero, so it is false in the boolean sense while
/// every real op is true.
enum AffineHighPrecOp {
  /// Null value.
  HNoOp = 0,
  /// Multiplication.
  Mul = 1,
  /// Floor division.
  FloorDiv = 2,
  /// Ceiling division.
  CeilDiv = 3,
  /// Modulo.
  Mod = 4
};
| |
| namespace { |
| /// This is a specialized parser for affine structures (affine maps, affine |
| /// expressions, and integer sets), maintaining the state transient to their |
| /// bodies. |
| class AffineParser : public Parser { |
| public: |
| AffineParser(ParserState &state, bool allowParsingSSAIds = false, |
| function_ref<ParseResult(bool)> parseElement = nullptr) |
| : Parser(state), allowParsingSSAIds(allowParsingSSAIds), |
| parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} |
| |
| AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); |
| ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); |
| IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); |
| ParseResult parseAffineMapOfSSAIds(AffineMap &map, |
| OpAsmParser::Delimiter delimiter); |
| void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, |
| unsigned &numDims); |
| |
| private: |
| // Binary affine op parsing. |
| AffineLowPrecOp consumeIfLowPrecOp(); |
| AffineHighPrecOp consumeIfHighPrecOp(); |
| |
| // Identifier lists for polyhedral structures. |
| ParseResult parseDimIdList(unsigned &numDims); |
| ParseResult parseSymbolIdList(unsigned &numSymbols); |
| ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, |
| unsigned &numSymbols); |
| ParseResult parseIdentifierDefinition(AffineExpr idExpr); |
| |
| AffineExpr parseAffineExpr(); |
| AffineExpr parseParentheticalExpr(); |
| AffineExpr parseNegateExpression(AffineExpr lhs); |
| AffineExpr parseIntegerExpr(); |
| AffineExpr parseBareIdExpr(); |
| AffineExpr parseSSAIdExpr(bool isSymbol); |
| AffineExpr parseSymbolSSAIdExpr(); |
| |
| AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, |
| AffineExpr rhs, SMLoc opLoc); |
| AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, |
| AffineExpr rhs); |
| AffineExpr parseAffineOperandExpr(AffineExpr lhs); |
| AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); |
| AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, |
| SMLoc llhsOpLoc); |
| AffineExpr parseAffineConstraint(bool *isEq); |
| |
| private: |
| bool allowParsingSSAIds; |
| function_ref<ParseResult(bool)> parseElement; |
| unsigned numDimOperands; |
| unsigned numSymbolOperands; |
| SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; |
| }; |
| } // end anonymous namespace |
| |
| /// Create an affine binary high precedence op expression (mul's, div's, mod). |
| /// opLoc is the location of the op token to be used to report errors |
| /// for non-conforming expressions. |
| AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, |
| AffineExpr lhs, AffineExpr rhs, |
| SMLoc opLoc) { |
| // TODO: make the error location info accurate. |
| switch (op) { |
| case Mul: |
| if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { |
| emitError(opLoc, "non-affine expression: at least one of the multiply " |
| "operands has to be either a constant or symbolic"); |
| return nullptr; |
| } |
| return lhs * rhs; |
| case FloorDiv: |
| if (!rhs.isSymbolicOrConstant()) { |
| emitError(opLoc, "non-affine expression: right operand of floordiv " |
| "has to be either a constant or symbolic"); |
| return nullptr; |
| } |
| return lhs.floorDiv(rhs); |
| case CeilDiv: |
| if (!rhs.isSymbolicOrConstant()) { |
| emitError(opLoc, "non-affine expression: right operand of ceildiv " |
| "has to be either a constant or symbolic"); |
| return nullptr; |
| } |
| return lhs.ceilDiv(rhs); |
| case Mod: |
| if (!rhs.isSymbolicOrConstant()) { |
| emitError(opLoc, "non-affine expression: right operand of mod " |
| "has to be either a constant or symbolic"); |
| return nullptr; |
| } |
| return lhs % rhs; |
| case HNoOp: |
| llvm_unreachable("can't create affine expression for null high prec op"); |
| return nullptr; |
| } |
| llvm_unreachable("Unknown AffineHighPrecOp"); |
| } |
| |
| /// Create an affine binary low precedence op expression (add, sub). |
| AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, |
| AffineExpr lhs, AffineExpr rhs) { |
| switch (op) { |
| case AffineLowPrecOp::Add: |
| return lhs + rhs; |
| case AffineLowPrecOp::Sub: |
| return lhs - rhs; |
| case AffineLowPrecOp::LNoOp: |
| llvm_unreachable("can't create affine expression for null low prec op"); |
| return nullptr; |
| } |
| llvm_unreachable("Unknown AffineLowPrecOp"); |
| } |
| |
| /// Consume this token if it is a lower precedence affine op (there are only |
| /// two precedence levels). |
| AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { |
| switch (getToken().getKind()) { |
| case Token::plus: |
| consumeToken(Token::plus); |
| return AffineLowPrecOp::Add; |
| case Token::minus: |
| consumeToken(Token::minus); |
| return AffineLowPrecOp::Sub; |
| default: |
| return AffineLowPrecOp::LNoOp; |
| } |
| } |
| |
| /// Consume this token if it is a higher precedence affine op (there are only |
| /// two precedence levels) |
| AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { |
| switch (getToken().getKind()) { |
| case Token::star: |
| consumeToken(Token::star); |
| return Mul; |
| case Token::kw_floordiv: |
| consumeToken(Token::kw_floordiv); |
| return FloorDiv; |
| case Token::kw_ceildiv: |
| consumeToken(Token::kw_ceildiv); |
| return CeilDiv; |
| case Token::kw_mod: |
| consumeToken(Token::kw_mod); |
| return Mod; |
| default: |
| return HNoOp; |
| } |
| } |
| |
| /// Parse a high precedence op expression list: mul, div, and mod are high |
| /// precedence binary ops, i.e., parse a |
| /// expr_1 op_1 expr_2 op_2 ... expr_n |
| /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). |
| /// All affine binary ops are left associative. |
| /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is |
| /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is |
| /// null. llhsOpLoc is the location of the llhsOp token that will be used to |
| /// report an error for non-conforming expressions. |
| AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, |
| AffineHighPrecOp llhsOp, |
| SMLoc llhsOpLoc) { |
| AffineExpr lhs = parseAffineOperandExpr(llhs); |
| if (!lhs) |
| return nullptr; |
| |
| // Found an LHS. Parse the remaining expression. |
| auto opLoc = getToken().getLoc(); |
| if (AffineHighPrecOp op = consumeIfHighPrecOp()) { |
| if (llhs) { |
| AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); |
| if (!expr) |
| return nullptr; |
| return parseAffineHighPrecOpExpr(expr, op, opLoc); |
| } |
| // No LLHS, get RHS |
| return parseAffineHighPrecOpExpr(lhs, op, opLoc); |
| } |
| |
| // This is the last operand in this expression. |
| if (llhs) |
| return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); |
| |
| // No llhs, 'lhs' itself is the expression. |
| return lhs; |
| } |
| |
| /// Parse an affine expression inside parentheses. |
| /// |
| /// affine-expr ::= `(` affine-expr `)` |
| AffineExpr AffineParser::parseParentheticalExpr() { |
| if (parseToken(Token::l_paren, "expected '('")) |
| return nullptr; |
| if (getToken().is(Token::r_paren)) |
| return (emitError("no expression inside parentheses"), nullptr); |
| |
| auto expr = parseAffineExpr(); |
| if (!expr) |
| return nullptr; |
| if (parseToken(Token::r_paren, "expected ')'")) |
| return nullptr; |
| |
| return expr; |
| } |
| |
| /// Parse the negation expression. |
| /// |
| /// affine-expr ::= `-` affine-expr |
| AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { |
| if (parseToken(Token::minus, "expected '-'")) |
| return nullptr; |
| |
| AffineExpr operand = parseAffineOperandExpr(lhs); |
| // Since negation has the highest precedence of all ops (including high |
| // precedence ops) but lower than parentheses, we are only going to use |
| // parseAffineOperandExpr instead of parseAffineExpr here. |
| if (!operand) |
| // Extra error message although parseAffineOperandExpr would have |
| // complained. Leads to a better diagnostic. |
| return (emitError("missing operand of negation"), nullptr); |
| return (-1) * operand; |
| } |
| |
| /// Parse a bare id that may appear in an affine expression. |
| /// |
| /// affine-expr ::= bare-id |
| AffineExpr AffineParser::parseBareIdExpr() { |
| if (getToken().isNot(Token::bare_identifier)) |
| return (emitError("expected bare identifier"), nullptr); |
| |
| StringRef sRef = getTokenSpelling(); |
| for (auto entry : dimsAndSymbols) { |
| if (entry.first == sRef) { |
| consumeToken(Token::bare_identifier); |
| return entry.second; |
| } |
| } |
| |
| return (emitError("use of undeclared identifier"), nullptr); |
| } |
| |
| /// Parse an SSA id which may appear in an affine expression. |
| AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { |
| if (!allowParsingSSAIds) |
| return (emitError("unexpected ssa identifier"), nullptr); |
| if (getToken().isNot(Token::percent_identifier)) |
| return (emitError("expected ssa identifier"), nullptr); |
| auto name = getTokenSpelling(); |
| // Check if we already parsed this SSA id. |
| for (auto entry : dimsAndSymbols) { |
| if (entry.first == name) { |
| consumeToken(Token::percent_identifier); |
| return entry.second; |
| } |
| } |
| // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. |
| if (parseElement(isSymbol)) |
| return (emitError("failed to parse ssa identifier"), nullptr); |
| auto idExpr = isSymbol |
| ? getAffineSymbolExpr(numSymbolOperands++, getContext()) |
| : getAffineDimExpr(numDimOperands++, getContext()); |
| dimsAndSymbols.push_back({name, idExpr}); |
| return idExpr; |
| } |
| |
| AffineExpr AffineParser::parseSymbolSSAIdExpr() { |
| if (parseToken(Token::kw_symbol, "expected symbol keyword") || |
| parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) |
| return nullptr; |
| AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); |
| if (!symbolExpr) |
| return nullptr; |
| if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) |
| return nullptr; |
| return symbolExpr; |
| } |
| |
| /// Parse a positive integral constant appearing in an affine expression. |
| /// |
| /// affine-expr ::= integer-literal |
| AffineExpr AffineParser::parseIntegerExpr() { |
| auto val = getToken().getUInt64IntegerValue(); |
| if (!val.hasValue() || (int64_t)val.getValue() < 0) |
| return (emitError("constant too large for index"), nullptr); |
| |
| consumeToken(Token::integer); |
| return builder.getAffineConstantExpr((int64_t)val.getValue()); |
| } |
| |
| /// Parses an expression that can be a valid operand of an affine expression. |
| /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary |
| /// operator, the rhs of which is being parsed. This is used to determine |
| /// whether an error should be emitted for a missing right operand. |
| // Eg: for an expression without parentheses (like i + j + k + l), each |
| // of the four identifiers is an operand. For i + j*k + l, j*k is not an |
| // operand expression, it's an op expression and will be parsed via |
| // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and |
| // -l are valid operands that will be parsed by this function. |
| AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { |
| switch (getToken().getKind()) { |
| case Token::bare_identifier: |
| return parseBareIdExpr(); |
| case Token::kw_symbol: |
| return parseSymbolSSAIdExpr(); |
| case Token::percent_identifier: |
| return parseSSAIdExpr(/*isSymbol=*/false); |
| case Token::integer: |
| return parseIntegerExpr(); |
| case Token::l_paren: |
| return parseParentheticalExpr(); |
| case Token::minus: |
| return parseNegateExpression(lhs); |
| case Token::kw_ceildiv: |
| case Token::kw_floordiv: |
| case Token::kw_mod: |
| case Token::plus: |
| case Token::star: |
| if (lhs) |
| emitError("missing right operand of binary operator"); |
| else |
| emitError("missing left operand of binary operator"); |
| return nullptr; |
| default: |
| if (lhs) |
| emitError("missing right operand of binary operator"); |
| else |
| emitError("expected affine expression"); |
| return nullptr; |
| } |
| } |
| |
| /// Parse affine expressions that are bare-id's, integer constants, |
| /// parenthetical affine expressions, and affine op expressions that are a |
| /// composition of those. |
| /// |
| /// All binary op's associate from left to right. |
| /// |
| /// {add, sub} have lower precedence than {mul, div, and mod}. |
| /// |
| /// Add, sub'are themselves at the same precedence level. Mul, floordiv, |
| /// ceildiv, and mod are at the same higher precedence level. Negation has |
| /// higher precedence than any binary op. |
| /// |
| /// llhs: the affine expression appearing on the left of the one being parsed. |
| /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, |
| /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned |
| /// if llhs is non-null; otherwise lhs is returned. This is to deal with left |
| /// associativity. |
| /// |
| /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function |
| /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where |
| /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). |
| AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, |
| AffineLowPrecOp llhsOp) { |
| AffineExpr lhs; |
| if (!(lhs = parseAffineOperandExpr(llhs))) |
| return nullptr; |
| |
| // Found an LHS. Deal with the ops. |
| if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { |
| if (llhs) { |
| AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); |
| return parseAffineLowPrecOpExpr(sum, lOp); |
| } |
| // No LLHS, get RHS and form the expression. |
| return parseAffineLowPrecOpExpr(lhs, lOp); |
| } |
| auto opLoc = getToken().getLoc(); |
| if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { |
| // We have a higher precedence op here. Get the rhs operand for the llhs |
| // through parseAffineHighPrecOpExpr. |
| AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); |
| if (!highRes) |
| return nullptr; |
| |
| // If llhs is null, the product forms the first operand of the yet to be |
| // found expression. If non-null, the op to associate with llhs is llhsOp. |
| AffineExpr expr = |
| llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; |
| |
| // Recurse for subsequent low prec op's after the affine high prec op |
| // expression. |
| if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) |
| return parseAffineLowPrecOpExpr(expr, nextOp); |
| return expr; |
| } |
| // Last operand in the expression list. |
| if (llhs) |
| return getAffineBinaryOpExpr(llhsOp, llhs, lhs); |
| // No llhs, 'lhs' itself is the expression. |
| return lhs; |
| } |
| |
| /// Parse an affine expression. |
| /// affine-expr ::= `(` affine-expr `)` |
| /// | `-` affine-expr |
| /// | affine-expr `+` affine-expr |
| /// | affine-expr `-` affine-expr |
| /// | affine-expr `*` affine-expr |
| /// | affine-expr `floordiv` affine-expr |
| /// | affine-expr `ceildiv` affine-expr |
| /// | affine-expr `mod` affine-expr |
| /// | bare-id |
| /// | integer-literal |
| /// |
| /// Additional conditions are checked depending on the production. For eg., |
| /// one of the operands for `*` has to be either constant/symbolic; the second |
| /// operand for floordiv, ceildiv, and mod has to be a positive integer. |
| AffineExpr AffineParser::parseAffineExpr() { |
| return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); |
| } |
| |
| /// Parse a dim or symbol from the lists appearing before the actual |
| /// expressions of the affine map. Update our state to store the |
| /// dimensional/symbolic identifier. |
| ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { |
| if (getToken().isNot(Token::bare_identifier)) |
| return emitError("expected bare identifier"); |
| |
| auto name = getTokenSpelling(); |
| for (auto entry : dimsAndSymbols) { |
| if (entry.first == name) |
| return emitError("redefinition of identifier '" + name + "'"); |
| } |
| consumeToken(Token::bare_identifier); |
| |
| dimsAndSymbols.push_back({name, idExpr}); |
| return success(); |
| } |
| |
| /// Parse the list of dimensional identifiers to an affine map. |
| ParseResult AffineParser::parseDimIdList(unsigned &numDims) { |
| if (parseToken(Token::l_paren, |
| "expected '(' at start of dimensional identifiers list")) { |
| return failure(); |
| } |
| |
| auto parseElt = [&]() -> ParseResult { |
| auto dimension = getAffineDimExpr(numDims++, getContext()); |
| return parseIdentifierDefinition(dimension); |
| }; |
| return parseCommaSeparatedListUntil(Token::r_paren, parseElt); |
| } |
| |
| /// Parse the list of symbolic identifiers to an affine map. |
| ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { |
| consumeToken(Token::l_square); |
| auto parseElt = [&]() -> ParseResult { |
| auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); |
| return parseIdentifierDefinition(symbol); |
| }; |
| return parseCommaSeparatedListUntil(Token::r_square, parseElt); |
| } |
| |
| /// Parse the list of symbolic identifiers to an affine map. |
| ParseResult |
| AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, |
| unsigned &numSymbols) { |
| if (parseDimIdList(numDims)) { |
| return failure(); |
| } |
| if (!getToken().is(Token::l_square)) { |
| numSymbols = 0; |
| return success(); |
| } |
| return parseSymbolIdList(numSymbols); |
| } |
| |
| /// Parses an ambiguous affine map or integer set definition inline. |
| ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, |
| IntegerSet &set) { |
| unsigned numDims = 0, numSymbols = 0; |
| |
| // List of dimensional and optional symbol identifiers. |
| if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { |
| return failure(); |
| } |
| |
| // This is needed for parsing attributes as we wouldn't know whether we would |
| // be parsing an integer set attribute or an affine map attribute. |
| bool isArrow = getToken().is(Token::arrow); |
| bool isColon = getToken().is(Token::colon); |
| if (!isArrow && !isColon) { |
| return emitError("expected '->' or ':'"); |
| } else if (isArrow) { |
| parseToken(Token::arrow, "expected '->' or '['"); |
| map = parseAffineMapRange(numDims, numSymbols); |
| return map ? success() : failure(); |
| } else if (parseToken(Token::colon, "expected ':' or '['")) { |
| return failure(); |
| } |
| |
| if ((set = parseIntegerSetConstraints(numDims, numSymbols))) |
| return success(); |
| |
| return failure(); |
| } |
| |
| /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. |
| ParseResult |
| AffineParser::parseAffineMapOfSSAIds(AffineMap &map, |
| OpAsmParser::Delimiter delimiter) { |
| Token::Kind rightToken; |
| switch (delimiter) { |
| case OpAsmParser::Delimiter::Square: |
| if (parseToken(Token::l_square, "expected '['")) |
| return failure(); |
| rightToken = Token::r_square; |
| break; |
| case OpAsmParser::Delimiter::Paren: |
| if (parseToken(Token::l_paren, "expected '('")) |
| return failure(); |
| rightToken = Token::r_paren; |
| break; |
| default: |
| return emitError("unexpected delimiter"); |
| } |
| |
| SmallVector<AffineExpr, 4> exprs; |
| auto parseElt = [&]() -> ParseResult { |
| auto elt = parseAffineExpr(); |
| exprs.push_back(elt); |
| return elt ? success() : failure(); |
| }; |
| |
| // Parse a multi-dimensional affine expression (a comma-separated list of |
| // 1-d affine expressions); the list cannot be empty. Grammar: |
| // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) |
| if (parseCommaSeparatedListUntil(rightToken, parseElt, |
| /*allowEmptyList=*/true)) |
| return failure(); |
| // Parsed a valid affine map. |
| if (exprs.empty()) |
| map = AffineMap::get(getContext()); |
| else |
| map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, |
| exprs); |
| return success(); |
| } |
| |
| /// Parse the range and sizes affine map definition inline. |
| /// |
| /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr |
| /// |
| /// multi-dim-affine-expr ::= `(` `)` |
| /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` |
| AffineMap AffineParser::parseAffineMapRange(unsigned numDims, |
| unsigned numSymbols) { |
| parseToken(Token::l_paren, "expected '(' at start of affine map range"); |
| |
| SmallVector<AffineExpr, 4> exprs; |
| auto parseElt = [&]() -> ParseResult { |
| auto elt = parseAffineExpr(); |
| ParseResult res = elt ? success() : failure(); |
| exprs.push_back(elt); |
| return res; |
| }; |
| |
| // Parse a multi-dimensional affine expression (a comma-separated list of |
| // 1-d affine expressions); the list cannot be empty. Grammar: |
| // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) |
| if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) |
| return AffineMap(); |
| |
| if (exprs.empty()) |
| return AffineMap::get(getContext()); |
| |
| // Parsed a valid affine map. |
| return AffineMap::get(numDims, numSymbols, exprs); |
| } |
| |
| /// Parse an affine constraint. |
| /// affine-constraint ::= affine-expr `>=` `0` |
| /// | affine-expr `==` `0` |
| /// |
| /// isEq is set to true if the parsed constraint is an equality, false if it |
| /// is an inequality (greater than or equal). |
| /// |
| AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { |
| AffineExpr expr = parseAffineExpr(); |
| if (!expr) |
| return nullptr; |
| |
| if (consumeIf(Token::greater) && consumeIf(Token::equal) && |
| getToken().is(Token::integer)) { |
| auto dim = getToken().getUnsignedIntegerValue(); |
| if (dim.hasValue() && dim.getValue() == 0) { |
| consumeToken(Token::integer); |
| *isEq = false; |
| return expr; |
| } |
| return (emitError("expected '0' after '>='"), nullptr); |
| } |
| |
| if (consumeIf(Token::equal) && consumeIf(Token::equal) && |
| getToken().is(Token::integer)) { |
| auto dim = getToken().getUnsignedIntegerValue(); |
| if (dim.hasValue() && dim.getValue() == 0) { |
| consumeToken(Token::integer); |
| *isEq = true; |
| return expr; |
| } |
| return (emitError("expected '0' after '=='"), nullptr); |
| } |
| |
| return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), |
| nullptr); |
| } |
| |
| /// Parse the constraints that are part of an integer set definition. |
| /// integer-set-inline |
| /// ::= dim-and-symbol-id-lists `:` |
| /// '(' affine-constraint-conjunction? ')' |
| /// affine-constraint-conjunction ::= affine-constraint (`,` |
| /// affine-constraint)* |
| /// |
| IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, |
| unsigned numSymbols) { |
| if (parseToken(Token::l_paren, |
| "expected '(' at start of integer set constraint list")) |
| return IntegerSet(); |
| |
| SmallVector<AffineExpr, 4> constraints; |
| SmallVector<bool, 4> isEqs; |
| auto parseElt = [&]() -> ParseResult { |
| bool isEq; |
| auto elt = parseAffineConstraint(&isEq); |
| ParseResult res = elt ? success() : failure(); |
| if (elt) { |
| constraints.push_back(elt); |
| isEqs.push_back(isEq); |
| } |
| return res; |
| }; |
| |
| // Parse a list of affine constraints (comma-separated). |
| if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) |
| return IntegerSet(); |
| |
| // If no constraints were parsed, then treat this as a degenerate 'true' case. |
| if (constraints.empty()) { |
| /* 0 == 0 */ |
| auto zero = getAffineConstantExpr(0, getContext()); |
| return IntegerSet::get(numDims, numSymbols, zero, true); |
| } |
| |
| // Parsed a valid integer set. |
| return IntegerSet::get(numDims, numSymbols, constraints, isEqs); |
| } |
| |
| /// Parse an ambiguous reference to either and affine map or an integer set. |
| ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, |
| IntegerSet &set) { |
| return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); |
| } |
| ParseResult Parser::parseAffineMapReference(AffineMap &map) { |
| llvm::SMLoc curLoc = getToken().getLoc(); |
| IntegerSet set; |
| if (parseAffineMapOrIntegerSetReference(map, set)) |
| return failure(); |
| if (set) |
| return emitError(curLoc, "expected AffineMap, but got IntegerSet"); |
| return success(); |
| } |
| ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { |
| llvm::SMLoc curLoc = getToken().getLoc(); |
| AffineMap map; |
| if (parseAffineMapOrIntegerSetReference(map, set)) |
| return failure(); |
| if (map) |
| return emitError(curLoc, "expected IntegerSet, but got AffineMap"); |
| return success(); |
| } |
| |
| /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to |
| /// parse SSA value uses encountered while parsing affine expressions. |
| ParseResult |
| Parser::parseAffineMapOfSSAIds(AffineMap &map, |
| function_ref<ParseResult(bool)> parseElement, |
| OpAsmParser::Delimiter delimiter) { |
| return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) |
| .parseAffineMapOfSSAIds(map, delimiter); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OperationParser |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class provides support for parsing operations and regions of |
| /// operations. |
| class OperationParser : public Parser { |
| public: |
| OperationParser(ParserState &state, ModuleOp moduleOp) |
| : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { |
| } |
| |
| ~OperationParser(); |
| |
| /// After parsing is finished, this function must be called to see if there |
| /// are any remaining issues. |
| ParseResult finalize(); |
| |
| //===--------------------------------------------------------------------===// |
| // SSA Value Handling |
| //===--------------------------------------------------------------------===// |
| |
| /// This represents a use of an SSA value in the program. The first two |
| /// entries in the tuple are the name and result number of a reference. The |
| /// third is the location of the reference, which is used in case this ends |
| /// up being a use of an undefined value. |
| struct SSAUseInfo { |
| StringRef name; // Value name, e.g. %42 or %abc |
| unsigned number; // Number, specified with #12 |
| SMLoc loc; // Location of first definition or use. |
| }; |
| |
| /// Push a new SSA name scope to the parser. |
| void pushSSANameScope(bool isIsolated); |
| |
| /// Pop the last SSA name scope from the parser. |
| ParseResult popSSANameScope(); |
| |
| /// Register a definition of a value with the symbol table. |
| ParseResult addDefinition(SSAUseInfo useInfo, Value value); |
| |
| /// Parse an optional list of SSA uses into 'results'. |
| ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); |
| |
| /// Parse a single SSA use into 'result'. |
| ParseResult parseSSAUse(SSAUseInfo &result); |
| |
| /// Given a reference to an SSA value and its type, return a reference. This |
| /// returns null on failure. |
| Value resolveSSAUse(SSAUseInfo useInfo, Type type); |
| |
| ParseResult parseSSADefOrUseAndType( |
| const std::function<ParseResult(SSAUseInfo, Type)> &action); |
| |
| ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); |
| |
| /// Return the location of the value identified by its name and number if it |
| /// has been already reference. |
| Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { |
| auto &values = isolatedNameScopes.back().values; |
| if (!values.count(name) || number >= values[name].size()) |
| return {}; |
| if (values[name][number].first) |
| return values[name][number].second; |
| return {}; |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Operation Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse an operation instance. |
| ParseResult parseOperation(); |
| |
| /// Parse a single operation successor and its operand list. |
| ParseResult parseSuccessorAndUseList(Block *&dest, |
| SmallVectorImpl<Value> &operands); |
| |
| /// Parse a comma-separated list of operation successors in brackets. |
| ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations, |
| SmallVectorImpl<SmallVector<Value, 4>> &operands); |
| |
| /// Parse an operation instance that is in the generic form. |
| Operation *parseGenericOperation(); |
| |
| /// Parse an operation instance that is in the generic form and insert it at |
| /// the provided insertion point. |
| Operation *parseGenericOperation(Block *insertBlock, |
| Block::iterator insertPt); |
| |
| /// Parse an operation instance that is in the op-defined custom form. |
| Operation *parseCustomOperation(); |
| |
| //===--------------------------------------------------------------------===// |
| // Region Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a region into 'region' with the provided entry block arguments. |
| /// 'isIsolatedNameScope' indicates if the naming scope of this region is |
| /// isolated from those above. |
| ParseResult parseRegion(Region ®ion, |
| ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, |
| bool isIsolatedNameScope = false); |
| |
| /// Parse a region body into 'region'. |
| ParseResult parseRegionBody(Region ®ion); |
| |
| //===--------------------------------------------------------------------===// |
| // Block Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a new block into 'block'. |
| ParseResult parseBlock(Block *&block); |
| |
| /// Parse a list of operations into 'block'. |
| ParseResult parseBlockBody(Block *block); |
| |
| /// Parse a (possibly empty) list of block arguments. |
| ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, |
| Block *owner); |
| |
| /// Get the block with the specified name, creating it if it doesn't |
| /// already exist. The location specified is the point of use, which allows |
| /// us to diagnose references to blocks that are not defined precisely. |
| Block *getBlockNamed(StringRef name, SMLoc loc); |
| |
| /// Define the block with the specified name. Returns the Block* or nullptr in |
| /// the case of redefinition. |
| Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); |
| |
| private: |
| /// Returns the info for a block at the current scope for the given name. |
| std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { |
| return blocksByName.back()[name]; |
| } |
| |
| /// Insert a new forward reference to the given block. |
| void insertForwardRef(Block *block, SMLoc loc) { |
| forwardRef.back().try_emplace(block, loc); |
| } |
| |
| /// Erase any forward reference to the given block. |
| bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } |
| |
| /// Record that a definition was added at the current scope. |
| void recordDefinition(StringRef def); |
| |
| /// Get the value entry for the given SSA name. |
| SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); |
| |
| /// Create a forward reference placeholder value with the given location and |
| /// result type. |
| Value createForwardRefPlaceholder(SMLoc loc, Type type); |
| |
| /// Return true if this is a forward reference. |
| bool isForwardRefPlaceholder(Value value) { |
| return forwardRefPlaceholders.count(value); |
| } |
| |
| /// This struct represents an isolated SSA name scope. This scope may contain |
| /// other nested non-isolated scopes. These scopes are used for operations |
| /// that are known to be isolated to allow for reusing names within their |
| /// regions, even if those names are used above. |
| struct IsolatedSSANameScope { |
| /// Record that a definition was added at the current scope. |
| void recordDefinition(StringRef def) { |
| definitionsPerScope.back().insert(def); |
| } |
| |
| /// Push a nested name scope. |
| void pushSSANameScope() { definitionsPerScope.push_back({}); } |
| |
| /// Pop a nested name scope. |
| void popSSANameScope() { |
| for (auto &def : definitionsPerScope.pop_back_val()) |
| values.erase(def.getKey()); |
| } |
| |
| /// This keeps track of all of the SSA values we are tracking for each name |
| /// scope, indexed by their name. This has one entry per result number. |
| llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; |
| |
| /// This keeps track of all of the values defined by a specific name scope. |
| SmallVector<llvm::StringSet<>, 2> definitionsPerScope; |
| }; |
| |
| /// A list of isolated name scopes. |
| SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; |
| |
| /// This keeps track of the block names as well as the location of the first |
| /// reference for each nested name scope. This is used to diagnose invalid |
| /// block references and memorize them. |
| SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; |
| SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; |
| |
| /// These are all of the placeholders we've made along with the location of |
| /// their first reference, to allow checking for use of undefined values. |
| DenseMap<Value, SMLoc> forwardRefPlaceholders; |
| |
| /// The builder used when creating parsed operation instances. |
| OpBuilder opBuilder; |
| |
| /// The top level module operation. |
| ModuleOp moduleOp; |
| }; |
| } // end anonymous namespace |
| |
| OperationParser::~OperationParser() { |
| for (auto &fwd : forwardRefPlaceholders) { |
| // Drop all uses of undefined forward declared reference and destroy |
| // defining operation. |
| fwd.first.dropAllUses(); |
| fwd.first.getDefiningOp()->destroy(); |
| } |
| } |
| |
| /// After parsing is finished, this function must be called to see if there are |
| /// any remaining issues. |
| ParseResult OperationParser::finalize() { |
| // Check for any forward references that are left. If we find any, error |
| // out. |
| if (!forwardRefPlaceholders.empty()) { |
| SmallVector<std::pair<const char *, Value>, 4> errors; |
| // Iteration over the map isn't deterministic, so sort by source location. |
| for (auto entry : forwardRefPlaceholders) |
| errors.push_back({entry.second.getPointer(), entry.first}); |
| llvm::array_pod_sort(errors.begin(), errors.end()); |
| |
| for (auto entry : errors) { |
| auto loc = SMLoc::getFromPointer(entry.first); |
| emitError(loc, "use of undeclared SSA value name"); |
| } |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SSA Value Handling |
| //===----------------------------------------------------------------------===// |
| |
| void OperationParser::pushSSANameScope(bool isIsolated) { |
| blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); |
| forwardRef.push_back(DenseMap<Block *, SMLoc>()); |
| |
| // Push back a new name definition scope. |
| if (isIsolated) |
| isolatedNameScopes.push_back({}); |
| isolatedNameScopes.back().pushSSANameScope(); |
| } |
| |
| ParseResult OperationParser::popSSANameScope() { |
| auto forwardRefInCurrentScope = forwardRef.pop_back_val(); |
| |
| // Verify that all referenced blocks were defined. |
| if (!forwardRefInCurrentScope.empty()) { |
| SmallVector<std::pair<const char *, Block *>, 4> errors; |
| // Iteration over the map isn't deterministic, so sort by source location. |
| for (auto entry : forwardRefInCurrentScope) { |
| errors.push_back({entry.second.getPointer(), entry.first}); |
| // Add this block to the top-level region to allow for automatic cleanup. |
| moduleOp.getOperation()->getRegion(0).push_back(entry.first); |
| } |
| llvm::array_pod_sort(errors.begin(), errors.end()); |
| |
| for (auto entry : errors) { |
| auto loc = SMLoc::getFromPointer(entry.first); |
| emitError(loc, "reference to an undefined block"); |
| } |
| return failure(); |
| } |
| |
| // Pop the next nested namescope. If there is only one internal namescope, |
| // just pop the isolated scope. |
| auto ¤tNameScope = isolatedNameScopes.back(); |
| if (currentNameScope.definitionsPerScope.size() == 1) |
| isolatedNameScopes.pop_back(); |
| else |
| currentNameScope.popSSANameScope(); |
| |
| blocksByName.pop_back(); |
| return success(); |
| } |
| |
| /// Register a definition of a value with the symbol table. |
| ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { |
| auto &entries = getSSAValueEntry(useInfo.name); |
| |
| // Make sure there is a slot for this value. |
| if (entries.size() <= useInfo.number) |
| entries.resize(useInfo.number + 1); |
| |
| // If we already have an entry for this, check to see if it was a definition |
| // or a forward reference. |
| if (auto existing = entries[useInfo.number].first) { |
| if (!isForwardRefPlaceholder(existing)) { |
| return emitError(useInfo.loc) |
| .append("redefinition of SSA value '", useInfo.name, "'") |
| .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) |
| .append("previously defined here"); |
| } |
| |
| // If it was a forward reference, update everything that used it to use |
| // the actual definition instead, delete the forward ref, and remove it |
| // from our set of forward references we track. |
| existing.replaceAllUsesWith(value); |
| existing.getDefiningOp()->destroy(); |
| forwardRefPlaceholders.erase(existing); |
| } |
| |
| /// Record this definition for the current scope. |
| entries[useInfo.number] = {value, useInfo.loc}; |
| recordDefinition(useInfo.name); |
| return success(); |
| } |
| |
| /// Parse a (possibly empty) list of SSA operands. |
| /// |
| /// ssa-use-list ::= ssa-use (`,` ssa-use)* |
| /// ssa-use-list-opt ::= ssa-use-list? |
| /// |
| ParseResult |
| OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { |
| if (getToken().isNot(Token::percent_identifier)) |
| return success(); |
| return parseCommaSeparatedList([&]() -> ParseResult { |
| SSAUseInfo result; |
| if (parseSSAUse(result)) |
| return failure(); |
| results.push_back(result); |
| return success(); |
| }); |
| } |
| |
| /// Parse a SSA operand for an operation. |
| /// |
| /// ssa-use ::= ssa-id |
| /// |
| ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { |
| result.name = getTokenSpelling(); |
| result.number = 0; |
| result.loc = getToken().getLoc(); |
| if (parseToken(Token::percent_identifier, "expected SSA operand")) |
| return failure(); |
| |
| // If we have an attribute ID, it is a result number. |
| if (getToken().is(Token::hash_identifier)) { |
| if (auto value = getToken().getHashIdentifierNumber()) |
| result.number = value.getValue(); |
| else |
| return emitError("invalid SSA value result number"); |
| consumeToken(Token::hash_identifier); |
| } |
| |
| return success(); |
| } |
| |
| /// Given an unbound reference to an SSA value and its type, return the value |
| /// it specifies. This returns null on failure. |
| Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { |
| auto &entries = getSSAValueEntry(useInfo.name); |
| |
| // If we have already seen a value of this name, return it. |
| if (useInfo.number < entries.size() && entries[useInfo.number].first) { |
| auto result = entries[useInfo.number].first; |
| // Check that the type matches the other uses. |
| if (result.getType() == type) |
| return result; |
| |
| emitError(useInfo.loc, "use of value '") |
| .append(useInfo.name, |
| "' expects different type than prior uses: ", type, " vs ", |
| result.getType()) |
| .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) |
| .append("prior use here"); |
| return nullptr; |
| } |
| |
| // Make sure we have enough slots for this. |
| if (entries.size() <= useInfo.number) |
| entries.resize(useInfo.number + 1); |
| |
| // If the value has already been defined and this is an overly large result |
| // number, diagnose that. |
| if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) |
| return (emitError(useInfo.loc, "reference to invalid result number"), |
| nullptr); |
| |
| // Otherwise, this is a forward reference. Create a placeholder and remember |
| // that we did so. |
| auto result = createForwardRefPlaceholder(useInfo.loc, type); |
| entries[useInfo.number].first = result; |
| entries[useInfo.number].second = useInfo.loc; |
| return result; |
| } |
| |
| /// Parse an SSA use with an associated type. |
| /// |
| /// ssa-use-and-type ::= ssa-use `:` type |
| ParseResult OperationParser::parseSSADefOrUseAndType( |
| const std::function<ParseResult(SSAUseInfo, Type)> &action) { |
| SSAUseInfo useInfo; |
| if (parseSSAUse(useInfo) || |
| parseToken(Token::colon, "expected ':' and type for SSA operand")) |
| return failure(); |
| |
| auto type = parseType(); |
| if (!type) |
| return failure(); |
| |
| return action(useInfo, type); |
| } |
| |
| /// Parse a (possibly empty) list of SSA operands, followed by a colon, then |
| /// followed by a type list. |
| /// |
| /// ssa-use-and-type-list |
| /// ::= ssa-use-list ':' type-list-no-parens |
| /// |
| ParseResult OperationParser::parseOptionalSSAUseAndTypeList( |
| SmallVectorImpl<Value> &results) { |
| SmallVector<SSAUseInfo, 4> valueIDs; |
| if (parseOptionalSSAUseList(valueIDs)) |
| return failure(); |
| |
| // If there were no operands, then there is no colon or type lists. |
| if (valueIDs.empty()) |
| return success(); |
| |
| SmallVector<Type, 4> types; |
| if (parseToken(Token::colon, "expected ':' in operand list") || |
| parseTypeListNoParens(types)) |
| return failure(); |
| |
| if (valueIDs.size() != types.size()) |
| return emitError("expected ") |
| << valueIDs.size() << " types to match operand list"; |
| |
| results.reserve(valueIDs.size()); |
| for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { |
| if (auto value = resolveSSAUse(valueIDs[i], types[i])) |
| results.push_back(value); |
| else |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| /// Record that a definition was added at the current scope. |
| void OperationParser::recordDefinition(StringRef def) { |
| isolatedNameScopes.back().recordDefinition(def); |
| } |
| |
| /// Get the value entry for the given SSA name. |
| SmallVectorImpl<std::pair<Value, SMLoc>> & |
| OperationParser::getSSAValueEntry(StringRef name) { |
| return isolatedNameScopes.back().values[name]; |
| } |
| |
| /// Create and remember a new placeholder for a forward reference. |
| Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { |
| // Forward references are always created as operations, because we just need |
| // something with a def/use chain. |
| // |
| // We create these placeholders as having an empty name, which we know |
| // cannot be created through normal user input, allowing us to distinguish |
| // them. |
| auto name = OperationName("placeholder", getContext()); |
| auto *op = Operation::create( |
| getEncodedSourceLocation(loc), name, type, /*operands=*/{}, |
| /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, |
| /*resizableOperandList=*/false); |
| forwardRefPlaceholders[op->getResult(0)] = loc; |
| return op->getResult(0); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Operation Parsing |
| //===----------------------------------------------------------------------===// |
| |
| /// Parse an operation. |
| /// |
| /// operation ::= op-result-list? |
| /// (generic-operation | custom-operation) |
| /// trailing-location? |
| /// generic-operation ::= string-literal `(` ssa-use-list? `)` |
| /// successor-list? (`(` region-list `)`)? |
| /// attribute-dict? `:` function-type |
| /// custom-operation ::= bare-id custom-operation-format |
| /// op-result-list ::= op-result (`,` op-result)* `=` |
| /// op-result ::= ssa-id (`:` integer-literal) |
| /// |
| ParseResult OperationParser::parseOperation() { |
| auto loc = getToken().getLoc(); |
| SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs; |
| size_t numExpectedResults = 0; |
| if (getToken().is(Token::percent_identifier)) { |
| // Parse the group of result ids. |
| auto parseNextResult = [&]() -> ParseResult { |
| // Parse the next result id. |
| if (!getToken().is(Token::percent_identifier)) |
| return emitError("expected valid ssa identifier"); |
| |
| Token nameTok = getToken(); |
| consumeToken(Token::percent_identifier); |
| |
| // If the next token is a ':', we parse the expected result count. |
| size_t expectedSubResults = 1; |
| if (consumeIf(Token::colon)) { |
| // Check that the next token is an integer. |
| if (!getToken().is(Token::integer)) |
| return emitError("expected integer number of results"); |
| |
| // Check that number of results is > 0. |
| auto val = getToken().getUInt64IntegerValue(); |
| if (!val.hasValue() || val.getValue() < 1) |
| return emitError("expected named operation to have atleast 1 result"); |
| consumeToken(Token::integer); |
| expectedSubResults = *val; |
| } |
| |
| resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, |
| nameTok.getLoc()); |
| numExpectedResults += expectedSubResults; |
| return success(); |
| }; |
| if (parseCommaSeparatedList(parseNextResult)) |
| return failure(); |
| |
| if (parseToken(Token::equal, "expected '=' after SSA name")) |
| return failure(); |
| } |
| |
| Operation *op; |
| if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) |
| op = parseCustomOperation(); |
| else if (getToken().is(Token::string)) |
| op = parseGenericOperation(); |
| else |
| return emitError("expected operation name in quotes"); |
| |
| // If parsing of the basic operation failed, then this whole thing fails. |
| if (!op) |
| return failure(); |
| |
| // If the operation had a name, register it. |
| if (!resultIDs.empty()) { |
| if (op->getNumResults() == 0) |
| return emitError(loc, "cannot name an operation with no results"); |
| if (numExpectedResults != op->getNumResults()) |
| return emitError(loc, "operation defines ") |
| << op->getNumResults() << " results but was provided " |
| << numExpectedResults << " to bind"; |
| |
| // Add definitions for each of the result groups. |
| unsigned opResI = 0; |
| for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) { |
| for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { |
| if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, |
| op->getResult(opResI++))) |
| return failure(); |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| /// Parse a single operation successor and its operand list. |
| /// |
| /// successor ::= block-id branch-use-list? |
| /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` |
| /// |
| ParseResult |
| OperationParser::parseSuccessorAndUseList(Block *&dest, |
| SmallVectorImpl<Value> &operands) { |
| // Verify branch is identifier and get the matching block. |
| if (!getToken().is(Token::caret_identifier)) |
| return emitError("expected block name"); |
| dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); |
| consumeToken(); |
| |
| // Handle optional arguments. |
| if (consumeIf(Token::l_paren) && |
| (parseOptionalSSAUseAndTypeList(operands) || |
| parseToken(Token::r_paren, "expected ')' to close argument list"))) { |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| /// Parse a comma-separated list of operation successors in brackets. |
| /// |
| /// successor-list ::= `[` successor (`,` successor )* `]` |
| /// |
| ParseResult OperationParser::parseSuccessors( |
| SmallVectorImpl<Block *> &destinations, |
| SmallVectorImpl<SmallVector<Value, 4>> &operands) { |
| if (parseToken(Token::l_square, "expected '['")) |
| return failure(); |
| |
| auto parseElt = [this, &destinations, &operands]() { |
| Block *dest; |
| SmallVector<Value, 4> destOperands; |
| auto res = parseSuccessorAndUseList(dest, destOperands); |
| destinations.push_back(dest); |
| operands.push_back(destOperands); |
| return res; |
| }; |
| return parseCommaSeparatedListUntil(Token::r_square, parseElt, |
| /*allowEmptyList=*/false); |
| } |
| |
| namespace { |
| // RAII-style guard for cleaning up the regions in the operation state before |
| // deleting them. Within the parser, regions may get deleted if parsing failed, |
| // and other errors may be present, in particular undominated uses. This makes |
| // sure such uses are deleted. |
| struct CleanupOpStateRegions { |
| ~CleanupOpStateRegions() { |
| SmallVector<Region *, 4> regionsToClean; |
| regionsToClean.reserve(state.regions.size()); |
| for (auto ®ion : state.regions) |
| if (region) |
| for (auto &block : *region) |
| block.dropAllDefinedValueUses(); |
| } |
| OperationState &state; |
| }; |
| } // namespace |
| |
| Operation *OperationParser::parseGenericOperation() { |
| // Get location information for the operation. |
| auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); |
| |
| auto name = getToken().getStringValue(); |
| if (name.empty()) |
| return (emitError("empty operation name is invalid"), nullptr); |
| if (name.find('\0') != StringRef::npos) |
| return (emitError("null character not allowed in operation name"), nullptr); |
| |
| consumeToken(Token::string); |
| |
| OperationState result(srcLocation, name); |
| |
| // Generic operations have a resizable operation list. |
| result.setOperandListToResizable(); |
| |
| // Parse the operand list. |
| SmallVector<SSAUseInfo, 8> operandInfos; |
| |
| if (parseToken(Token::l_paren, "expected '(' to start operand list") || |
| parseOptionalSSAUseList(operandInfos) || |
| parseToken(Token::r_paren, "expected ')' to end operand list")) { |
| return nullptr; |
| } |
| |
| // Parse the successor list but don't add successors to the result yet to |
| // avoid messing up with the argument order. |
| SmallVector<Block *, 2> successors; |
| SmallVector<SmallVector<Value, 4>, 2> successorOperands; |
| if (getToken().is(Token::l_square)) { |
| // Check if the operation is a known terminator. |
| const AbstractOperation *abstractOp = result.name.getAbstractOperation(); |
| if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) |
| return emitError("successors in non-terminator"), nullptr; |
| if (parseSuccessors(successors, successorOperands)) |
| return nullptr; |
| } |
| |
| // Parse the region list. |
| CleanupOpStateRegions guard{result}; |
| if (consumeIf(Token::l_paren)) { |
| do { |
| // Create temporary regions with the top level region as parent. |
| result.regions.emplace_back(new Region(moduleOp)); |
| if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) |
| return nullptr; |
| } while (consumeIf(Token::comma)); |
| if (parseToken(Token::r_paren, "expected ')' to end region list")) |
| return nullptr; |
| } |
| |
| if (getToken().is(Token::l_brace)) { |
| if (parseAttributeDict(result.attributes)) |
| return nullptr; |
| } |
| |
| if (parseToken(Token::colon, "expected ':' followed by operation type")) |
| return nullptr; |
| |
| auto typeLoc = getToken().getLoc(); |
| auto type = parseType(); |
| if (!type) |
| return nullptr; |
| auto fnType = type.dyn_cast<FunctionType>(); |
| if (!fnType) |
| return (emitError(typeLoc, "expected function type"), nullptr); |
| |
| result.addTypes(fnType.getResults()); |
| |
| // Check that we have the right number of types for the operands. |
| auto operandTypes = fnType.getInputs(); |
| if (operandTypes.size() != operandInfos.size()) { |
| auto plural = "s"[operandInfos.size() == 1]; |
| return (emitError(typeLoc, "expected ") |
| << operandInfos.size() << " operand type" << plural |
| << " but had " << operandTypes.size(), |
| nullptr); |
| } |
| |
| // Resolve all of the operands. |
| for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { |
| result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); |
| if (!result.operands.back()) |
| return nullptr; |
| } |
| |
| // Add the successors, and their operands after the proper operands. |
| for (auto succ : llvm::zip(successors, successorOperands)) { |
| Block *successor = std::get<0>(succ); |
| const SmallVector<Value, 4> &operands = std::get<1>(succ); |
| result.addSuccessor(successor, operands); |
| } |
| |
| // Parse a location if one is present. |
| if (parseOptionalTrailingLocation(result.location)) |
| return nullptr; |
| |
| return opBuilder.createOperation(result); |
| } |
| |
| Operation *OperationParser::parseGenericOperation(Block *insertBlock, |
| Block::iterator insertPt) { |
| OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); |
| opBuilder.setInsertionPoint(insertBlock, insertPt); |
| return parseGenericOperation(); |
| } |
| |
| namespace { |
| class CustomOpAsmParser : public OpAsmParser { |
| public: |
| CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, |
| OperationParser &parser) |
| : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} |
| |
| /// Parse an instance of the operation described by 'opDefinition' into the |
| /// provided operation state. |
| ParseResult parseOperation(OperationState &opState) { |
| if (opDefinition->parseAssembly(*this, opState)) |
| return failure(); |
| return success(); |
| } |
| |
| Operation *parseGenericOperation(Block *insertBlock, |
| Block::iterator insertPt) final { |
| return parser.parseGenericOperation(insertBlock, insertPt); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Utilities |
| //===--------------------------------------------------------------------===// |
| |
| /// Return if any errors were emitted during parsing. |
| bool didEmitError() const { return emittedError; } |
| |
| /// Emit a diagnostic at the specified location and return failure. |
| InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { |
| emittedError = true; |
| return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + |
| message); |
| } |
| |
| llvm::SMLoc getCurrentLocation() override { |
| return parser.getToken().getLoc(); |
| } |
| |
| Builder &getBuilder() const override { return parser.builder; } |
| |
| llvm::SMLoc getNameLoc() const override { return nameLoc; } |
| |
| //===--------------------------------------------------------------------===// |
| // Token Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a `->` token. |
| ParseResult parseArrow() override { |
| return parser.parseToken(Token::arrow, "expected '->'"); |
| } |
| |
| /// Parses a `->` if present. |
| ParseResult parseOptionalArrow() override { |
| return success(parser.consumeIf(Token::arrow)); |
| } |
| |
| /// Parse a `:` token. |
| ParseResult parseColon() override { |
| return parser.parseToken(Token::colon, "expected ':'"); |
| } |
| |
| /// Parse a `:` token if present. |
| ParseResult parseOptionalColon() override { |
| return success(parser.consumeIf(Token::colon)); |
| } |
| |
| /// Parse a `,` token. |
| ParseResult parseComma() override { |
| return parser.parseToken(Token::comma, "expected ','"); |
| } |
| |
| /// Parse a `,` token if present. |
| ParseResult parseOptionalComma() override { |
| return success(parser.consumeIf(Token::comma)); |
| } |
| |
| /// Parses a `...` if present. |
| ParseResult parseOptionalEllipsis() override { |
| return success(parser.consumeIf(Token::ellipsis)); |
| } |
| |
| /// Parse a `=` token. |
| ParseResult parseEqual() override { |
| return parser.parseToken(Token::equal, "expected '='"); |
| } |
| |
| /// Parse a '<' token. |
| ParseResult parseLess() override { |
| return parser.parseToken(Token::less, "expected '<'"); |
| } |
| |
| /// Parse a '>' token. |
| ParseResult parseGreater() override { |
| return parser.parseToken(Token::greater, "expected '>'"); |
| } |
| |
| /// Parse a `(` token. |
| ParseResult parseLParen() override { |
| return parser.parseToken(Token::l_paren, "expected '('"); |
| } |
| |
| /// Parses a '(' if present. |
| ParseResult parseOptionalLParen() override { |
| return success(parser.consumeIf(Token::l_paren)); |
| } |
| |
| /// Parse a `)` token. |
| ParseResult parseRParen() override { |
| return parser.parseToken(Token::r_paren, "expected ')'"); |
| } |
| |
| /// Parses a ')' if present. |
| ParseResult parseOptionalRParen() override { |
| return success(parser.consumeIf(Token::r_paren)); |
| } |
| |
| /// Parse a `[` token. |
| ParseResult parseLSquare() override { |
| return parser.parseToken(Token::l_square, "expected '['"); |
| } |
| |
| /// Parses a '[' if present. |
| ParseResult parseOptionalLSquare() override { |
| return success(parser.consumeIf(Token::l_square)); |
| } |
| |
| /// Parse a `]` token. |
| ParseResult parseRSquare() override { |
| return parser.parseToken(Token::r_square, "expected ']'"); |
| } |
| |
| /// Parses a ']' if present. |
| ParseResult parseOptionalRSquare() override { |
| return success(parser.consumeIf(Token::r_square)); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Attribute Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse an arbitrary attribute of a given type and return it in result. This |
| /// also adds the attribute to the specified attribute list with the specified |
| /// name. |
| ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, |
| SmallVectorImpl<NamedAttribute> &attrs) override { |
| result = parser.parseAttribute(type); |
| if (!result) |
| return failure(); |
| |
| attrs.push_back(parser.builder.getNamedAttr(attrName, result)); |
| return success(); |
| } |
| |
| /// Parse a named dictionary into 'result' if it is present. |
| ParseResult |
| parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { |
| if (parser.getToken().isNot(Token::l_brace)) |
| return success(); |
| return parser.parseAttributeDict(result); |
| } |
| |
| /// Parse a named dictionary into 'result' if the `attributes` keyword is |
| /// present. |
| ParseResult parseOptionalAttrDictWithKeyword( |
| SmallVectorImpl<NamedAttribute> &result) override { |
| if (failed(parseOptionalKeyword("attributes"))) |
| return success(); |
| return parser.parseAttributeDict(result); |
| } |
| |
| /// Parse an affine map instance into 'map'. |
| ParseResult parseAffineMap(AffineMap &map) override { |
| return parser.parseAffineMapReference(map); |
| } |
| |
| /// Parse an integer set instance into 'set'. |
| ParseResult printIntegerSet(IntegerSet &set) override { |
| return parser.parseIntegerSetReference(set); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Identifier Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Returns if the current token corresponds to a keyword. |
| bool isCurrentTokenAKeyword() const { |
| return parser.getToken().is(Token::bare_identifier) || |
| parser.getToken().isKeyword(); |
| } |
| |
| /// Parse the given keyword if present. |
| ParseResult parseOptionalKeyword(StringRef keyword) override { |
| // Check that the current token has the same spelling. |
| if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) |
| return failure(); |
| parser.consumeToken(); |
| return success(); |
| } |
| |
| /// Parse a keyword, if present, into 'keyword'. |
| ParseResult parseOptionalKeyword(StringRef *keyword) override { |
| // Check that the current token is a keyword. |
| if (!isCurrentTokenAKeyword()) |
| return failure(); |
| |
| *keyword = parser.getTokenSpelling(); |
| parser.consumeToken(); |
| return success(); |
| } |
| |
| /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
| /// string attribute named 'attrName'. |
| ParseResult |
| parseOptionalSymbolName(StringAttr &result, StringRef attrName, |
| SmallVectorImpl<NamedAttribute> &attrs) override { |
| Token atToken = parser.getToken(); |
| if (atToken.isNot(Token::at_identifier)) |
| return failure(); |
| |
| result = getBuilder().getStringAttr(extractSymbolReference(atToken)); |
| attrs.push_back(getBuilder().getNamedAttr(attrName, result)); |
| parser.consumeToken(); |
| return success(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Operand Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a single operand. |
| ParseResult parseOperand(OperandType &result) override { |
| OperationParser::SSAUseInfo useInfo; |
| if (parser.parseSSAUse(useInfo)) |
| return failure(); |
| |
| result = {useInfo.loc, useInfo.name, useInfo.number}; |
| return success(); |
| } |
| |
| /// Parse zero or more SSA comma-separated operand references with a specified |
| /// surrounding delimiter, and an optional required operand count. |
| ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, |
| int requiredOperandCount = -1, |
| Delimiter delimiter = Delimiter::None) override { |
| return parseOperandOrRegionArgList(result, /*isOperandList=*/true, |
| requiredOperandCount, delimiter); |
| } |
| |
| /// Parse zero or more SSA comma-separated operand or region arguments with |
| /// optional surrounding delimiter and required operand count. |
| ParseResult |
| parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, |
| bool isOperandList, int requiredOperandCount = -1, |
| Delimiter delimiter = Delimiter::None) { |
| auto startLoc = parser.getToken().getLoc(); |
| |
| // Handle delimiters. |
| switch (delimiter) { |
| case Delimiter::None: |
| // Don't check for the absence of a delimiter if the number of operands |
| // is unknown (and hence the operand list could be empty). |
| if (requiredOperandCount == -1) |
| break; |
| // Token already matches an identifier and so can't be a delimiter. |
| if (parser.getToken().is(Token::percent_identifier)) |
| break; |
| // Test against known delimiters. |
| if (parser.getToken().is(Token::l_paren) || |
| parser.getToken().is(Token::l_square)) |
| return emitError(startLoc, "unexpected delimiter"); |
| return emitError(startLoc, "invalid operand"); |
| case Delimiter::OptionalParen: |
| if (parser.getToken().isNot(Token::l_paren)) |
| return success(); |
| LLVM_FALLTHROUGH; |
| case Delimiter::Paren: |
| if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) |
| return failure(); |
| break; |
| case Delimiter::OptionalSquare: |
| if (parser.getToken().isNot(Token::l_square)) |
| return success(); |
| LLVM_FALLTHROUGH; |
| case Delimiter::Square: |
| if (parser.parseToken(Token::l_square, "expected '[' in operand list")) |
| return failure(); |
| break; |
| } |
| |
| // Check for zero operands. |
| if (parser.getToken().is(Token::percent_identifier)) { |
| do { |
| OperandType operandOrArg; |
| if (isOperandList ? parseOperand(operandOrArg) |
| : parseRegionArgument(operandOrArg)) |
| return failure(); |
| result.push_back(operandOrArg); |
| } while (parser.consumeIf(Token::comma)); |
| } |
| |
| // Handle delimiters. If we reach here, the optional delimiters were |
| // present, so we need to parse their closing one. |
| switch (delimiter) { |
| case Delimiter::None: |
| break; |
| case Delimiter::OptionalParen: |
| case Delimiter::Paren: |
| if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) |
| return failure(); |
| break; |
| case Delimiter::OptionalSquare: |
| case Delimiter::Square: |
| if (parser.parseToken(Token::r_square, "expected ']' in operand list")) |
| return failure(); |
| break; |
| } |
| |
| if (requiredOperandCount != -1 && |
| result.size() != static_cast<size_t>(requiredOperandCount)) |
| return emitError(startLoc, "expected ") |
| << requiredOperandCount << " operands"; |
| return success(); |
| } |
| |
| /// Parse zero or more trailing SSA comma-separated trailing operand |
| /// references with a specified surrounding delimiter, and an optional |
| /// required operand count. A leading comma is expected before the operands. |
| ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, |
| int requiredOperandCount, |
| Delimiter delimiter) override { |
| if (parser.getToken().is(Token::comma)) { |
| parseComma(); |
| return parseOperandList(result, requiredOperandCount, delimiter); |
| } |
| if (requiredOperandCount != -1) |
| return emitError(parser.getToken().getLoc(), "expected ") |
| << requiredOperandCount << " operands"; |
| return success(); |
| } |
| |
| /// Resolve an operand to an SSA value, emitting an error on failure. |
| ParseResult resolveOperand(const OperandType &operand, Type type, |
| SmallVectorImpl<Value> &result) override { |
| OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, |
| operand.location}; |
| if (auto value = parser.resolveSSAUse(operandInfo, type)) { |
| result.push_back(value); |
| return success(); |
| } |
| return failure(); |
| } |
| |
| /// Parse an AffineMap of SSA ids. |
| ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, |
| Attribute &mapAttr, StringRef attrName, |
| SmallVectorImpl<NamedAttribute> &attrs, |
| Delimiter delimiter) override { |
| SmallVector<OperandType, 2> dimOperands; |
| SmallVector<OperandType, 1> symOperands; |
| |
| auto parseElement = [&](bool isSymbol) -> ParseResult { |
| OperandType operand; |
| if (parseOperand(operand)) |
| return failure(); |
| if (isSymbol) |
| symOperands.push_back(operand); |
| else |
| dimOperands.push_back(operand); |
| return success(); |
| }; |
| |
| AffineMap map; |
| if (parser.parseAffineMapOfSSAIds(map, parseElement, delimiter)) |
| return failure(); |
| // Add AffineMap attribute. |
| if (map) { |
| mapAttr = AffineMapAttr::get(map); |
| attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); |
| } |
| |
| // Add dim operands before symbol operands in 'operands'. |
| operands.assign(dimOperands.begin(), dimOperands.end()); |
| operands.append(symOperands.begin(), symOperands.end()); |
| return success(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Region Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a region that takes `arguments` of `argTypes` types. This |
| /// effectively defines the SSA values of `arguments` and assigns their type. |
| ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, |
| ArrayRef<Type> argTypes, |
| bool enableNameShadowing) override { |
| assert(arguments.size() == argTypes.size() && |
| "mismatching number of arguments and types"); |
| |
| SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> |
| regionArguments; |
| for (auto pair : llvm::zip(arguments, argTypes)) { |
| const OperandType &operand = std::get<0>(pair); |
| Type type = std::get<1>(pair); |
| OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, |
| operand.location}; |
| regionArguments.emplace_back(operandInfo, type); |
| } |
| |
| // Try to parse the region. |
| assert((!enableNameShadowing || |
| opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && |
| "name shadowing is only allowed on isolated regions"); |
| if (parser.parseRegion(region, regionArguments, enableNameShadowing)) |
| return failure(); |
| return success(); |
| } |
| |
| /// Parses a region if present. |
| ParseResult parseOptionalRegion(Region ®ion, |
| ArrayRef<OperandType> arguments, |
| ArrayRef<Type> argTypes, |
| bool enableNameShadowing) override { |
| if (parser.getToken().isNot(Token::l_brace)) |
| return success(); |
| return parseRegion(region, arguments, argTypes, enableNameShadowing); |
| } |
| |
| /// Parse a region argument. The type of the argument will be resolved later |
| /// by a call to `parseRegion`. |
| ParseResult parseRegionArgument(OperandType &argument) override { |
| return parseOperand(argument); |
| } |
| |
| /// Parse a region argument if present. |
| ParseResult parseOptionalRegionArgument(OperandType &argument) override { |
| if (parser.getToken().isNot(Token::percent_identifier)) |
| return success(); |
| return parseRegionArgument(argument); |
| } |
| |
| ParseResult |
| parseRegionArgumentList(SmallVectorImpl<OperandType> &result, |
| int requiredOperandCount = -1, |
| Delimiter delimiter = Delimiter::None) override { |
| return parseOperandOrRegionArgList(result, /*isOperandList=*/false, |
| requiredOperandCount, delimiter); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Successor Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a single operation successor and its operand list. |
| ParseResult |
| parseSuccessorAndUseList(Block *&dest, |
| SmallVectorImpl<Value> &operands) override { |
| return parser.parseSuccessorAndUseList(dest, operands); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Type Parsing |
| //===--------------------------------------------------------------------===// |
| |
| /// Parse a type. |
| ParseResult parseType(Type &result) override { |
| return failure(!(result = parser.parseType())); |
| } |
| |
| /// Parse an arrow followed by a type list. |
| ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override { |
| if (parseArrow() || parser.parseFunctionResultTypes(result)) |
| return failure(); |
| return success(); |
| } |
| |
| /// Parse an optional arrow followed by a type list. |
| ParseResult |
| parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { |
| if (!parser.consumeIf(Token::arrow)) |
| return success(); |
| return parser.parseFunctionResultTypes(result); |
| } |
| |
| /// Parse a colon followed by a type. |
| ParseResult parseColonType(Type &result) override { |
| return failure(parser.parseToken(Token::colon, "expected ':'") || |
| !(result = parser.parseType())); |
| } |
| |
| /// Parse a colon followed by a type list, which must have at least one type. |
| ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { |
| if (parser.parseToken(Token::colon, "expected ':'")) |
| return failure(); |
| return parser.parseTypeListNoParens(result); |
| } |
| |
| /// Parse an optional colon followed by a type list, which if present must |
| /// have at least one type. |
| ParseResult |
| parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { |
| if (!parser.consumeIf(Token::colon)) |
| return success(); |
| return parser.parseTypeListNoParens(result); |
| } |
| |
| /// Parse a list of assignments of the form |
| /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). |
| /// The list must contain at least one entry |
| ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, |
| SmallVectorImpl<OperandType> &rhs) { |
| auto parseElt = [&]() -> ParseResult { |
| OperandType regionArg, operand; |
| Type type; |
| if (parseRegionArgument(regionArg) || parseEqual() || |
| parseOperand(operand)) |
| return failure(); |
| lhs.push_back(regionArg); |
| rhs.push_back(operand); |
| return success(); |
| }; |
| if (parseLParen()) |
| return failure(); |
| return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); |
| } |
| |
| private: |
| /// The source location of the operation name. |
| SMLoc nameLoc; |
| |
| /// The abstract information of the operation. |
| const AbstractOperation *opDefinition; |
| |
| /// The main operation parser. |
| OperationParser &parser; |
| |
| /// A flag that indicates if any errors were emitted during parsing. |
| bool emittedError = false; |
| }; |
| } // end anonymous namespace. |
| |
| Operation *OperationParser::parseCustomOperation() { |
| auto opLoc = getToken().getLoc(); |
| auto opName = getTokenSpelling(); |
| |
| auto *opDefinition = AbstractOperation::lookup(opName, getContext()); |
| if (!opDefinition && !opName.contains('.')) { |
| // If the operation name has no namespace prefix we treat it as a standard |
| // operation and prefix it with "std". |
| // TODO: Would it be better to just build a mapping of the registered |
| // operations in the standard dialect? |
| opDefinition = |
| AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); |
| } |
| |
| if (!opDefinition) { |
| emitError(opLoc) << "custom op '" << opName << "' is unknown"; |
| return nullptr; |
| } |
| |
| consumeToken(); |
| |
| // If the custom op parser crashes, produce some indication to help |
| // debugging. |
| std::string opNameStr = opName.str(); |
| llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", |
| opNameStr.c_str()); |
| |
| // Get location information for the operation. |
| auto srcLocation = getEncodedSourceLocation(opLoc); |
| |
| // Have the op implementation take a crack and parsing this. |
| OperationState opState(srcLocation, opDefinition->name); |
| CleanupOpStateRegions guard{opState}; |
| CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); |
| if (opAsmParser.parseOperation(opState)) |
| return nullptr; |
| |
| // If it emitted an error, we failed. |
| if (opAsmParser.didEmitError()) |
| return nullptr; |
| |
| // Parse a location if one is present. |
| if (parseOptionalTrailingLocation(opState.location)) |
| return nullptr; |
| |
| // Otherwise, we succeeded. Use the state it parsed as our op information. |
| return opBuilder.createOperation(opState); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Region Parsing |
| //===----------------------------------------------------------------------===// |
| |
| /// Region. |
| /// |
| /// region ::= '{' region-body |
| /// |
| ParseResult OperationParser::parseRegion( |
| Region ®ion, |
| ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, |
| bool isIsolatedNameScope) { |
| // Parse the '{'. |
| if (parseToken(Token::l_brace, "expected '{' to begin a region")) |
| return failure(); |
| |
| // Check for an empty region. |
| if (entryArguments.empty() && consumeIf(Token::r_brace)) |
| return success(); |
| auto currentPt = opBuilder.saveInsertionPoint(); |
| |
| // Push a new named value scope. |
| pushSSANameScope(isIsolatedNameScope); |
| |
| // Parse the first block directly to allow for it to be unnamed. |
| Block *block = new Block(); |
| |
| // Add arguments to the entry block. |
| if (!entryArguments.empty()) { |
| for (auto &placeholderArgPair : entryArguments) { |
| auto &argInfo = placeholderArgPair.first; |
| // Ensure that the argument was not already defined. |
| if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { |
| return emitError(argInfo.loc, "region entry argument '" + argInfo.name + |
| "' is already in use") |
| .attachNote(getEncodedSourceLocation(*defLoc)) |
| << "previously referenced here"; |
| } |
| if (addDefinition(placeholderArgPair.first, |
| block->addArgument(placeholderArgPair.second))) { |
| delete block; |
| return failure(); |
| } |
| } |
| |
| // If we had named arguments, then don't allow a block name. |
| if (getToken().is(Token::caret_identifier)) |
| return emitError("invalid block name in region with named arguments"); |
| } |
| |
| if (parseBlock(block)) { |
| delete block; |
| return failure(); |
| } |
| |
| // Verify that no other arguments were parsed. |
| if (!entryArguments.empty() && |
| block->getNumArguments() > entryArguments.size()) { |
| delete block; |
| return emitError("entry block arguments were already defined"); |
| } |
| |
| // Parse the rest of the region. |
| region.push_back(block); |
| if (parseRegionBody(region)) |
| return failure(); |
| |
| // Pop the SSA value scope for this region. |
| if (popSSANameScope()) |
| return failure(); |
| |
| // Reset the original insertion point. |
| opBuilder.restoreInsertionPoint(currentPt); |
| return success(); |
| } |
| |
| /// Region. |
| /// |
| /// region-body ::= block* '}' |
| /// |
| ParseResult OperationParser::parseRegionBody(Region ®ion) { |
| // Parse the list of blocks. |
| while (!consumeIf(Token::r_brace)) { |
| Block *newBlock = nullptr; |
| if (parseBlock(newBlock)) |
| return failure(); |
| region.push_back(newBlock); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Block Parsing |
| //===----------------------------------------------------------------------===// |
| |
| /// Block declaration. |
| /// |
| /// block ::= block-label? operation* |
| /// block-label ::= block-id block-arg-list? `:` |
| /// block-id ::= caret-id |
| /// block-arg-list ::= `(` ssa-id-and-type-list? `)` |
| /// |
| ParseResult OperationParser::parseBlock(Block *&block) { |
| // The first block of a region may already exist, if it does the caret |
| // identifier is optional. |
| if (block && getToken().isNot(Token::caret_identifier)) |
| return parseBlockBody(block); |
| |
| SMLoc nameLoc = getToken().getLoc(); |
| auto name = getTokenSpelling(); |
| if (parseToken(Token::caret_identifier, "expected block name")) |
| return failure(); |
| |
| block = defineBlockNamed(name, nameLoc, block); |
| |
| // Fail if the block was already defined. |
| if (!block) |
| return emitError(nameLoc, "redefinition of block '") << name << "'"; |
| |
| // If an argument list is present, parse it. |
| if (consumeIf(Token::l_paren)) { |
| SmallVector<BlockArgument, 8> bbArgs; |
| if (parseOptionalBlockArgList(bbArgs, block) || |
| parseToken(Token::r_paren, "expected ')' to end argument list")) |
| return failure(); |
| } |
| |
| if (parseToken(Token::colon, "expected ':' after block name")) |
| return failure(); |
| |
| return parseBlockBody(block); |
| } |
| |
| ParseResult OperationParser::parseBlockBody(Block *block) { |
| // Set the insertion point to the end of the block to parse. |
| opBuilder.setInsertionPointToEnd(block); |
| |
| // Parse the list of operations that make up the body of the block. |
| while (getToken().isNot(Token::caret_identifier, Token::r_brace)) |
| if (parseOperation()) |
| return failure(); |
| |
| return success(); |
| } |
| |
| /// Get the block with the specified name, creating it if it doesn't already |
| /// exist. The location specified is the point of use, which allows |
| /// us to diagnose references to blocks that are not defined precisely. |
| Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { |
| auto &blockAndLoc = getBlockInfoByName(name); |
| if (!blockAndLoc.first) { |
| blockAndLoc = {new Block(), loc}; |
| insertForwardRef(blockAndLoc.first, loc); |
| } |
| |
| return blockAndLoc.first; |
| } |
| |
| /// Define the block with the specified name. Returns the Block* or nullptr in |
| /// the case of redefinition. |
| Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, |
| Block *existing) { |
| auto &blockAndLoc = getBlockInfoByName(name); |
| if (!blockAndLoc.first) { |
| // If the caller provided a block, use it. Otherwise create a new one. |
| if (!existing) |
| existing = new Block(); |
| blockAndLoc.first = existing; |
| blockAndLoc.second = loc; |
| return blockAndLoc.first; |
| } |
| |
| // Forward declarations are removed once defined, so if we are defining a |
| // existing block and it is not a forward declaration, then it is a |
| // redeclaration. |
| if (!eraseForwardRef(blockAndLoc.first)) |
| return nullptr; |
| return blockAndLoc.first; |
| } |
| |
| /// Parse a (possibly empty) list of SSA operands with types as block arguments. |
| /// |
| /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* |
| /// |
| ParseResult OperationParser::parseOptionalBlockArgList( |
| SmallVectorImpl<BlockArgument> &results, Block *owner) { |
| if (getToken().is(Token::r_brace)) |
| return success(); |
| |
| // If the block already has arguments, then we're handling the entry block. |
| // Parse and register the names for the arguments, but do not add them. |
| bool definingExistingArgs = owner->getNumArguments() != 0; |
| unsigned nextArgument = 0; |
| |
| return parseCommaSeparatedList([&]() -> ParseResult { |
| return parseSSADefOrUseAndType( |
| [&](SSAUseInfo useInfo, Type type) -> ParseResult { |
| // If this block did not have existing arguments, define a new one. |
| if (!definingExistingArgs) |
| return addDefinition(useInfo, owner->addArgument(type)); |
| |
| // Otherwise, ensure that this argument has already been created. |
| if (nextArgument >= owner->getNumArguments()) |
| return emitError("too many arguments specified in argument list"); |
| |
| // Finally, make sure the existing argument has the correct type. |
| auto arg = owner->getArgument(nextArgument++); |
| if (arg.getType() != type) |
| return emitError("argument and block argument type mismatch"); |
| return addDefinition(useInfo, arg); |
| }); |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Top-level entity parsing. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This parser handles entities that are only valid at the top level of the |
| /// file. |
| class ModuleParser : public Parser { |
| public: |
| explicit ModuleParser(ParserState &state) : Parser(state) {} |
| |
| ParseResult parseModule(ModuleOp module); |
| |
| private: |
| /// Parse an attribute alias declaration. |
| ParseResult parseAttributeAliasDef(); |
| |
| /// Parse an attribute alias declaration. |
| ParseResult parseTypeAliasDef(); |
| }; |
| } // end anonymous namespace |
| |
| /// Parses an attribute alias declaration. |
| /// |
| /// attribute-alias-def ::= '#' alias-name `=` attribute-value |
| /// |
| ParseResult ModuleParser::parseAttributeAliasDef() { |
| assert(getToken().is(Token::hash_identifier)); |
| StringRef aliasName = getTokenSpelling().drop_front(); |
| |
| // Check for redefinitions. |
| if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) |
| return emitError("redefinition of attribute alias id '" + aliasName + "'"); |
| |
| // Make sure this isn't invading the dialect attribute namespace. |
| if (aliasName.contains('.')) |
| return emitError("attribute names with a '.' are reserved for " |
| "dialect-defined names"); |
| |
| consumeToken(Token::hash_identifier); |
| |
| // Parse the '='. |
| if (parseToken(Token::equal, "expected '=' in attribute alias definition")) |
| return failure(); |
| |
| // Parse the attribute value. |
| Attribute attr = parseAttribute(); |
| if (!attr) |
| return failure(); |
| |
| getState().symbols.attributeAliasDefinitions[aliasName] = attr; |
| return success(); |
| } |
| |
| /// Parse a type alias declaration. |
| /// |
| /// type-alias-def ::= '!' alias-name `=` 'type' type |
| /// |
| ParseResult ModuleParser::parseTypeAliasDef() { |
| assert(getToken().is(Token::exclamation_identifier)); |
| StringRef aliasName = getTokenSpelling().drop_front(); |
| |
| // Check for redefinitions. |
| if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) |
| return emitError("redefinition of type alias id '" + aliasName + "'"); |
| |
| // Make sure this isn't invading the dialect type namespace. |
| if (aliasName.contains('.')) |
| return emitError("type names with a '.' are reserved for " |
| "dialect-defined names"); |
| |
| consumeToken(Token::exclamation_identifier); |
| |
| // Parse the '=' and 'type'. |
| if (parseToken(Token::equal, "expected '=' in type alias definition") || |
| parseToken(Token::kw_type, "expected 'type' in type alias definition")) |
| return failure(); |
| |
| // Parse the type. |
| Type aliasedType = parseType(); |
| if (!aliasedType) |
| return failure(); |
| |
| // Register this alias with the parser state. |
| getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); |
| return success(); |
| } |
| |
| /// This is the top-level module parser. |
| ParseResult ModuleParser::parseModule(ModuleOp module) { |
| OperationParser opParser(getState(), module); |
| |
| // Module itself is a name scope. |
| opParser.pushSSANameScope(/*isIsolated=*/true); |
| |
| while (true) { |
| switch (getToken().getKind()) { |
| default: |
| // Parse a top-level operation. |
| if (opParser.parseOperation()) |
| return failure(); |
| break; |
| |
| // If we got to the end of the file, then we're done. |
| case Token::eof: { |
| if (opParser.finalize()) |
| return failure(); |
| |
| // Handle the case where the top level module was explicitly defined. |
| auto &bodyBlocks = module.getBodyRegion().getBlocks(); |
| auto &operations = bodyBlocks.front().getOperations(); |
| assert(!operations.empty() && "expected a valid module terminator"); |
| |
| // Check that the first operation is a module, and it is the only |
| // non-terminator operation. |
| ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); |
| if (nested && std::next(operations.begin(), 2) == operations.end()) { |
| // Merge the data of the nested module operation into 'module'. |
| module.setLoc(nested.getLoc()); |
| module.setAttrs(nested.getOperation()->getAttrList()); |
| bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); |
| |
| // Erase the original module body. |
| bodyBlocks.pop_front(); |
| } |
| |
| return opParser.popSSANameScope(); |
| } |
| |
| // If we got an error token, then the lexer already emitted an error, just |
| // stop. Someday we could introduce error recovery if there was demand |
| // for it. |
| case Token::error: |
| return failure(); |
| |
| // Parse an attribute alias. |
| case Token::hash_identifier: |
| if (parseAttributeAliasDef()) |
| return failure(); |
| break; |
| |
| // Parse a type alias. |
| case Token::exclamation_identifier: |
| if (parseTypeAliasDef()) |
| return failure(); |
| break; |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| |
| /// This parses the file specified by the indicated SourceMgr and returns an |
| /// MLIR module if it was valid. If not, it emits diagnostics and returns |
| /// null. |
| OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, |
| MLIRContext *context) { |
| auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); |
| |
| // This is the result module we are parsing into. |
| OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( |
| sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); |
| |
| SymbolState aliasState; |
| ParserState state(sourceMgr, context, aliasState); |
| if (ModuleParser(state).parseModule(*module)) |
| return nullptr; |
| |
| // Make sure the parse module has no other structural problems detected by |
| // the verifier. |
| if (failed(verify(*module))) |
| return nullptr; |
| |
| return module; |
| } |
| |
| /// This parses the file specified by the indicated filename and returns an |
| /// MLIR module if it was valid. If not, the error message is emitted through |
| /// the error handler registered in the context, and a null pointer is returned. |
| OwningModuleRef mlir::parseSourceFile(StringRef filename, |
| MLIRContext *context) { |
| llvm::SourceMgr sourceMgr; |
| return parseSourceFile(filename, sourceMgr, context); |
| } |
| |
| /// This parses the file specified by the indicated filename using the provided |
| /// SourceMgr and returns an MLIR module if it was valid. If not, the error |
| /// message is emitted through the error handler registered in the context, and |
| /// a null pointer is returned. |
| OwningModuleRef mlir::parseSourceFile(StringRef filename, |
| llvm::SourceMgr &sourceMgr, |
| MLIRContext *context) { |
| if (sourceMgr.getNumBuffers() != 0) { |
| // TODO(b/136086478): Extend to support multiple buffers. |
| emitError(mlir::UnknownLoc::get(context), |
| "only main buffer parsed at the moment"); |
| return nullptr; |
| } |
| auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); |
| if (std::error_code error = file_or_err.getError()) { |
| emitError(mlir::UnknownLoc::get(context), |
| "could not open input file " + filename); |
| return nullptr; |
| } |
| |
| // Load the MLIR module. |
| sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); |
| return parseSourceFile(sourceMgr, context); |
| } |
| |
| /// This parses the program string to a MLIR module if it was valid. If not, |
| /// it emits diagnostics and returns null. |
| OwningModuleRef mlir::parseSourceString(StringRef moduleStr, |
| MLIRContext *context) { |
| auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); |
| if (!memBuffer) |
| return nullptr; |
| |
| SourceMgr sourceMgr; |
| sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); |
| return parseSourceFile(sourceMgr, context); |
| } |
| |
| /// Parses a symbol, of type 'T', and returns it if parsing was successful. If |
| /// parsing failed, nullptr is returned. The number of bytes read from the input |
| /// string is returned in 'numRead'. |
| template <typename T, typename ParserFn> |
| static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, |
| ParserFn &&parserFn) { |
| SymbolState aliasState; |
| return parseSymbol<T>( |
| inputStr, context, aliasState, |
| [&](Parser &parser) { |
| SourceMgrDiagnosticHandler handler( |
| const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), |
| parser.getContext()); |
| return parserFn(parser); |
| }, |
| &numRead); |
| } |
| |
| Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { |
| size_t numRead = 0; |
| return parseAttribute(attrStr, context, numRead); |
| } |
| Attribute mlir::parseAttribute(StringRef attrStr, Type type) { |
| size_t numRead = 0; |
| return parseAttribute(attrStr, type, numRead); |
| } |
| |
| Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, |
| size_t &numRead) { |
| return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { |
| return parser.parseAttribute(); |
| }); |
| } |
| Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { |
| return parseSymbol<Attribute>( |
| attrStr, type.getContext(), numRead, |
| [type](Parser &parser) { return parser.parseAttribute(type); }); |
| } |
| |
| Type mlir::parseType(StringRef typeStr, MLIRContext *context) { |
| size_t numRead = 0; |
| return parseType(typeStr, context, numRead); |
| } |
| |
| Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { |
| return parseSymbol<Type>(typeStr, context, numRead, |
| [](Parser &parser) { return parser.parseType(); }); |
| } |