| //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===// |
| // |
| // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // OpDefinitionsGen uses the description of operations to generate C++ |
| // definitions for ops. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Support/STLExtras.h" |
| #include "mlir/TableGen/Format.h" |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/ODSDialectHook.h" |
| #include "mlir/TableGen/OpClass.h" |
| #include "mlir/TableGen/OpInterfaces.h" |
| #include "mlir/TableGen/OpTrait.h" |
| #include "mlir/TableGen/Operator.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/Support/ManagedStatic.h" |
| #include "llvm/Support/Signals.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Record.h" |
| #include "llvm/TableGen/TableGenBackend.h" |
| |
| #define DEBUG_TYPE "mlir-tblgen-opdefgen" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| |
| using llvm::CodeInit; |
| using llvm::DefInit; |
| using llvm::formatv; |
| using llvm::Init; |
| using llvm::ListInit; |
| using llvm::Record; |
| using llvm::RecordKeeper; |
| using llvm::StringInit; |
| |
| //===----------------------------------------------------------------------===// |
| // Dialect hook registration |
| //===----------------------------------------------------------------------===// |
| |
| static llvm::ManagedStatic<llvm::StringMap<DialectEmitFunction>> dialectHooks; |
| |
| ODSDialectHookRegistration::ODSDialectHookRegistration( |
| StringRef dialectName, DialectEmitFunction emitFn) { |
| bool inserted = dialectHooks->try_emplace(dialectName, emitFn).second; |
| assert(inserted && "Multiple ODS hooks for the same dialect!"); |
| (void)inserted; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Static string definitions |
| //===----------------------------------------------------------------------===// |
| |
// Names used inside generated code: a common prefix for generated
// identifiers, the base name for synthesized argument names, and the name of
// the OperationState parameter in generated build() methods. (constexpr
// namespace-scope variables are const and have internal linkage.)
constexpr const char *tblgenNamePrefix = "tblgen_";
constexpr const char *generatedArgName = "tblgen_arg";
constexpr const char *builderOpState = "tblgen_state";
| |
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// This is a formatv template: `{{` produces a literal `{`. The emitted code
// refers to `index`, the parameter of the generated getODS* "sink" getter.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: The begin iterator of the actual values.
// {5}: "operand" or "result".
//
// Declared `static const char *const` (like the other string constants in
// this file) so the symbol has internal linkage and cannot be reassigned.
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;

// Calculate how many dynamic values a static variadic {5} corresponds to.
// This assumes all static variadic {5}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {5} (variadic or not) as size 1. So here for each previous static variadic
// {5}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {5} starts.
int offset = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;

return {{std::next({4}, offset), std::next({4}, offset + size)};
)";
| |
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes the
// op has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// This is a formatv template: `{{` produces a literal `{`. The emitted code
// refers to `index`, the parameter of the generated getODS* "sink" getter,
// and sums the segment sizes of all preceding segments to find the start.
//
// {0}: The name of the attribute specifying the segment sizes.
// {1}: The begin iterator of the actual values.
//
// Declared `static const char *const` (like the other string constants in
// this file) so the symbol has internal linkage and cannot be reassigned.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
return {{std::next({1}, start), std::next({1}, end)};
)";
| |
// formatv template for the banner comment placed before each op in the
// generated output. The two placeholders are filled in by the caller.
constexpr const char *opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
| |
| //===----------------------------------------------------------------------===// |
| // Utility structs and functions |
| //===----------------------------------------------------------------------===// |
| |
| // Returns whether the record has a value of the given name that can be returned |
| // via getValueAsString. |
| static inline bool hasStringAttribute(const Record &record, |
| StringRef fieldName) { |
| auto valueInit = record.getValueInit(fieldName); |
| return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit); |
| } |
| |
| static std::string getArgumentName(const Operator &op, int index) { |
| const auto &operand = op.getOperand(index); |
| if (!operand.name.empty()) |
| return operand.name; |
| else |
| return formatv("{0}_{1}", generatedArgName, index); |
| } |
| |
| // Returns true if we can use unwrapped value for the given `attr` in builders. |
| static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { |
| return attr.getReturnType() != attr.getStorageType() && |
| // We need to wrap the raw value into an attribute in the builder impl |
| // so we need to make sure that the attribute specifies how to do that. |
| !attr.getConstBuilderTemplate().empty(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Op emitter |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Simple RAII helper for defining ifdef-undef-endif scopes. |
| class IfDefScope { |
| public: |
| IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) { |
| os << "#ifdef " << name << "\n" |
| << "#undef " << name << "\n\n"; |
| } |
| |
| ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; } |
| |
| private: |
| StringRef name; |
| raw_ostream &os; |
| }; |
| } // end anonymous namespace |
| |
| namespace { |
| // Helper class to emit a record into the given output stream. |
| class OpEmitter { |
| public: |
| static void emitDecl(const Operator &op, raw_ostream &os); |
| static void emitDef(const Operator &op, raw_ostream &os); |
| |
| private: |
| OpEmitter(const Operator &op); |
| |
| void emitDecl(raw_ostream &os); |
| void emitDef(raw_ostream &os); |
| |
| // Generates the OpAsmOpInterface for this operation if possible. |
| void genOpAsmInterface(); |
| |
| // Generates the `getOperationName` method for this op. |
| void genOpNameGetter(); |
| |
| // Generates getters for the attributes. |
| void genAttrGetters(); |
| |
| // Generates getters for named operands. |
| void genNamedOperandGetters(); |
| |
| // Generates getters for named results. |
| void genNamedResultGetters(); |
| |
| // Generates getters for named regions. |
| void genNamedRegionGetters(); |
| |
| // Generates builder methods for the operation. |
| void genBuilder(); |
| |
| // Generates the build() method that takes each operand/attribute |
| // as a stand-alone parameter. |
| void genSeparateArgParamBuilder(); |
| |
| // Generates the build() method that takes each operand/attribute as a |
| // stand-alone parameter. The generated build() method uses first operand's |
| // type as all results' types. |
| void genUseOperandAsResultTypeSeparateParamBuilder(); |
| |
| // Generates the build() method that takes all operands/attributes |
| // collectively as one parameter. The generated build() method uses first |
| // operand's type as all results' types. |
| void genUseOperandAsResultTypeCollectiveParamBuilder(); |
| |
| // Generates the build() method that takes aggregate operands/attributes |
| // parameters. This build() method uses inferred types as result types. |
| // Requires: The type needs to be inferable via InferTypeOpInterface. |
| void genInferedTypeCollectiveParamBuilder(); |
| |
| // Generates the build() method that takes each operand/attribute as a |
| // stand-alone parameter. The generated build() method uses first attribute's |
| // type as all result's types. |
| void genUseAttrAsResultTypeBuilder(); |
| |
| // Generates the build() method that takes all result types collectively as |
| // one parameter. Similarly for operands and attributes. |
| void genCollectiveParamBuilder(); |
| |
| // The kind of parameter to generate for result types in builders. |
| enum class TypeParamKind { |
| None, // No result type in parameter list. |
| Separate, // A separate parameter for each result type. |
| Collective, // An ArrayRef<Type> for all result types. |
| }; |
| |
| // The kind of parameter to generate for attributes in builders. |
| enum class AttrParamKind { |
| WrappedAttr, // A wrapped MLIR Attribute instance. |
| UnwrappedValue, // A raw value without MLIR Attribute wrapper. |
| }; |
| |
| // Builds the parameter list for build() method of this op. This method writes |
| // to `paramList` the comma-separated parameter list and updates |
| // `resultTypeNames` with the names for parameters for specifying result |
| // types. The given `typeParamKind` and `attrParamKind` controls how result |
| // types and attributes are placed in the parameter list. |
| void buildParamList(std::string ¶mList, |
| SmallVectorImpl<std::string> &resultTypeNames, |
| TypeParamKind typeParamKind, |
| AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); |
| |
| // Adds op arguments and regions into operation state for build() methods. |
| void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, |
| bool isRawValueAttr = false); |
| |
| // Generates canonicalizer declaration for the operation. |
| void genCanonicalizerDecls(); |
| |
| // Generates the folder declaration for the operation. |
| void genFolderDecls(); |
| |
| // Generates the parser for the operation. |
| void genParser(); |
| |
| // Generates the printer for the operation. |
| void genPrinter(); |
| |
| // Generates verify method for the operation. |
| void genVerifier(); |
| |
| // Generates verify statements for operands and results in the operation. |
| // The generated code will be attached to `body`. |
| void genOperandResultVerifier(OpMethodBody &body, |
| Operator::value_range values, |
| StringRef valueKind); |
| |
| // Generates verify statements for regions in the operation. |
| // The generated code will be attached to `body`. |
| void genRegionVerifier(OpMethodBody &body); |
| |
| // Generates the traits used by the object. |
| void genTraits(); |
| |
| // Generate the OpInterface methods. |
| void genOpInterfaceMethods(); |
| |
| private: |
| // The TableGen record for this op. |
| // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly, |
| // it should rather go through the Operator for better abstraction. |
| const Record &def; |
| |
| // The wrapper operator class for querying information from this op. |
| Operator op; |
| |
| // The C++ code builder for this op |
| OpClass opClass; |
| |
| // The format context for verification code generation. |
| FmtContext verifyCtx; |
| }; |
| } // end anonymous namespace |
| |
| OpEmitter::OpEmitter(const Operator &op) |
| : def(op.getDef()), op(op), |
| opClass(op.getCppClassName(), op.getExtraClassDeclaration()) { |
| verifyCtx.withOp("(*this->getOperation())"); |
| |
| genTraits(); |
| |
| // Generate C++ code for various op methods. The order here determines the |
| // methods in the generated file. |
| genOpAsmInterface(); |
| genOpNameGetter(); |
| genNamedOperandGetters(); |
| genNamedResultGetters(); |
| genNamedRegionGetters(); |
| genAttrGetters(); |
| genBuilder(); |
| genParser(); |
| genPrinter(); |
| genVerifier(); |
| genCanonicalizerDecls(); |
| genFolderDecls(); |
| genOpInterfaceMethods(); |
| |
| // If a dialect hook is registered for this op's dialect, emit dialect |
| // specific content. |
| auto dialectHookIt = dialectHooks->find(op.getDialectName()); |
| if (dialectHookIt != dialectHooks->end()) { |
| dialectHookIt->second(op, opClass); |
| } |
| } |
| |
| void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { |
| OpEmitter(op).emitDecl(os); |
| } |
| |
| void OpEmitter::emitDef(const Operator &op, raw_ostream &os) { |
| OpEmitter(op).emitDef(os); |
| } |
| |
| void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); } |
| |
| void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } |
| |
| void OpEmitter::genAttrGetters() { |
| FmtContext fctx; |
| fctx.withBuilder("mlir::Builder(this->getContext())"); |
| |
| // Emit the derived attribute body. |
| auto emitDerivedAttr = [&](StringRef name, Attribute attr) { |
| auto &method = opClass.newMethod(attr.getReturnType(), name); |
| auto &body = method.body(); |
| body << " " << attr.getDerivedCodeBody() << "\n"; |
| }; |
| |
| // Emit with return type specified. |
| auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { |
| auto &method = opClass.newMethod(attr.getReturnType(), name); |
| auto &body = method.body(); |
| body << " auto attr = " << name << "Attr();\n"; |
| if (attr.hasDefaultValue()) { |
| // Returns the default value if not set. |
| // TODO: this is inefficient, we are recreating the attribute for every |
| // call. This should be set instead. |
| std::string defaultValue = |
| tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()); |
| body << " if (!attr)\n return " |
| << tgfmt(attr.getConvertFromStorageCall(), |
| &fctx.withSelf(defaultValue)) |
| << ";\n"; |
| } |
| body << " return " |
| << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) |
| << ";\n"; |
| }; |
| |
| // Generate raw named accessor type. This is a wrapper class that allows |
| // referring to the attributes via accessors instead of having to use |
| // the string interface for better compile time verification. |
| auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { |
| auto &method = |
| opClass.newMethod(attr.getStorageType(), (name + "Attr").str()); |
| auto &body = method.body(); |
| body << " return this->getAttr(\"" << name << "\")."; |
| if (attr.isOptional() || attr.hasDefaultValue()) |
| body << "dyn_cast_or_null<"; |
| else |
| body << "cast<"; |
| body << attr.getStorageType() << ">();"; |
| }; |
| |
| for (auto &namedAttr : op.getAttributes()) { |
| const auto &name = namedAttr.name; |
| const auto &attr = namedAttr.attr; |
| if (attr.isDerivedAttr()) { |
| emitDerivedAttr(name, attr); |
| } else { |
| emitAttrWithStorageType(name, attr); |
| emitAttrWithReturnType(name, attr); |
| } |
| } |
| } |
| |
| // Generates the named operand getter methods for the given Operator `op` and |
| // puts them in `opClass`. Uses `rangeType` as the return type of getters that |
| // return a range of operands (individual operands are `Value ` and each |
| // element in the range must also be `Value `); use `rangeBeginCall` to get |
| // an iterator to the beginning of the operand range; use `rangeSizeCall` to |
| // obtain the number of operands. `getOperandCallPattern` contains the code |
| // necessary to obtain a single operand whose position will be substituted |
| // instead of |
| // "{0}" marker in the pattern. Note that the pattern should work for any kind |
| // of ops, in particular for one-operand ops that may not have the |
| // `getOperand(unsigned)` method. |
| static void generateNamedOperandGetters(const Operator &op, Class &opClass, |
| StringRef rangeType, |
| StringRef rangeBeginCall, |
| StringRef rangeSizeCall, |
| StringRef getOperandCallPattern) { |
| const int numOperands = op.getNumOperands(); |
| const int numVariadicOperands = op.getNumVariadicOperands(); |
| const int numNormalOperands = numOperands - numVariadicOperands; |
| |
| const auto *sameVariadicSize = |
| op.getTrait("OpTrait::SameVariadicOperandSize"); |
| const auto *attrSizedOperands = |
| op.getTrait("OpTrait::AttrSizedOperandSegments"); |
| |
| if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { |
| PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " |
| "specification over their sizes"); |
| } |
| |
| if (numVariadicOperands < 2 && attrSizedOperands) { |
| PrintFatalError(op.getLoc(), "op must have at least two variadic operands " |
| "to use 'AttrSizedOperandSegments' trait"); |
| } |
| |
| if (attrSizedOperands && sameVariadicSize) { |
| PrintFatalError(op.getLoc(), |
| "op cannot have both 'AttrSizedOperandSegments' and " |
| "'SameVariadicOperandSize' traits"); |
| } |
| |
| // First emit a "sink" getter method upon which we layer all nicer named |
| // getter methods. |
| auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); |
| |
| if (numVariadicOperands == 0) { |
| // We still need to match the return type, which is a range. |
| m.body() << " return {std::next(" << rangeBeginCall |
| << ", index), std::next(" << rangeBeginCall << ", index + 1)};"; |
| } else if (attrSizedOperands) { |
| m.body() << formatv(attrSizedSegmentValueRangeCalcCode, |
| "operand_segment_sizes", rangeBeginCall); |
| } else { |
| // Because the op can have arbitrarily interleaved variadic and non-variadic |
| // operands, we need to embed a list in the "sink" getter method for |
| // calculation at run-time. |
| llvm::SmallVector<StringRef, 4> isVariadic; |
| isVariadic.reserve(numOperands); |
| for (int i = 0; i < numOperands; ++i) { |
| isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic())); |
| } |
| std::string isVariadicList = llvm::join(isVariadic, ", "); |
| |
| m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, |
| numNormalOperands, numVariadicOperands, rangeSizeCall, |
| rangeBeginCall, "operand"); |
| } |
| |
| // Then we emit nicer named getter methods by redirecting to the "sink" getter |
| // method. |
| |
| for (int i = 0; i != numOperands; ++i) { |
| const auto &operand = op.getOperand(i); |
| if (operand.name.empty()) |
| continue; |
| |
| if (operand.isVariadic()) { |
| auto &m = opClass.newMethod(rangeType, operand.name); |
| m.body() << " return getODSOperands(" << i << ");"; |
| } else { |
| auto &m = opClass.newMethod("Value ", operand.name); |
| m.body() << " return *getODSOperands(" << i << ").begin();"; |
| } |
| } |
| } |
| |
| void OpEmitter::genNamedOperandGetters() { |
| if (op.getTrait("OpTrait::AttrSizedOperandSegments")) |
| opClass.setHasOperandAdaptorClass(false); |
| |
| generateNamedOperandGetters( |
| op, opClass, /*rangeType=*/"Operation::operand_range", |
| /*rangeBeginCall=*/"getOperation()->operand_begin()", |
| /*rangeSizeCall=*/"getOperation()->getNumOperands()", |
| /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); |
| } |
| |
| void OpEmitter::genNamedResultGetters() { |
| const int numResults = op.getNumResults(); |
| const int numVariadicResults = op.getNumVariadicResults(); |
| const int numNormalResults = numResults - numVariadicResults; |
| |
| // If we have more than one variadic results, we need more complicated logic |
| // to calculate the value range for each result. |
| |
| const auto *sameVariadicSize = op.getTrait("OpTrait::SameVariadicResultSize"); |
| const auto *attrSizedResults = |
| op.getTrait("OpTrait::AttrSizedResultSegments"); |
| |
| if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { |
| PrintFatalError(op.getLoc(), "op has multiple variadic results but no " |
| "specification over their sizes"); |
| } |
| |
| if (numVariadicResults < 2 && attrSizedResults) { |
| PrintFatalError(op.getLoc(), "op must have at least two variadic results " |
| "to use 'AttrSizedResultSegments' trait"); |
| } |
| |
| if (attrSizedResults && sameVariadicSize) { |
| PrintFatalError(op.getLoc(), |
| "op cannot have both 'AttrSizedResultSegments' and " |
| "'SameVariadicResultSize' traits"); |
| } |
| |
| auto &m = opClass.newMethod("Operation::result_range", "getODSResults", |
| "unsigned index"); |
| |
| if (numVariadicResults == 0) { |
| m.body() << " return {std::next(getOperation()->result_begin(), index), " |
| "std::next(getOperation()->result_begin(), index + 1)};"; |
| } else if (attrSizedResults) { |
| m.body() << formatv(attrSizedSegmentValueRangeCalcCode, |
| "result_segment_sizes", |
| "getOperation()->result_begin()"); |
| } else { |
| llvm::SmallVector<StringRef, 4> isVariadic; |
| isVariadic.reserve(numResults); |
| for (int i = 0; i < numResults; ++i) { |
| isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic())); |
| } |
| std::string isVariadicList = llvm::join(isVariadic, ", "); |
| |
| m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, |
| numNormalResults, numVariadicResults, |
| "getOperation()->getNumResults()", |
| "getOperation()->result_begin()", "result"); |
| } |
| |
| for (int i = 0; i != numResults; ++i) { |
| const auto &result = op.getResult(i); |
| if (result.name.empty()) |
| continue; |
| |
| if (result.isVariadic()) { |
| auto &m = opClass.newMethod("Operation::result_range", result.name); |
| m.body() << " return getODSResults(" << i << ");"; |
| } else { |
| auto &m = opClass.newMethod("Value ", result.name); |
| m.body() << " return *getODSResults(" << i << ").begin();"; |
| } |
| } |
| } |
| |
| void OpEmitter::genNamedRegionGetters() { |
| unsigned numRegions = op.getNumRegions(); |
| for (unsigned i = 0; i < numRegions; ++i) { |
| const auto ®ion = op.getRegion(i); |
| if (!region.name.empty()) { |
| auto &m = opClass.newMethod("Region &", region.name); |
| m.body() << formatv(" return this->getOperation()->getRegion({0});", i); |
| } |
| } |
| } |
| |
| static bool canGenerateUnwrappedBuilder(Operator &op) { |
| // If this op does not have native attributes at all, return directly to avoid |
| // redefining builders. |
| if (op.getNumNativeAttributes() == 0) |
| return false; |
| |
| bool canGenerate = false; |
| // We are generating builders that take raw values for attributes. We need to |
| // make sure the native attributes have a meaningful "unwrapped" value type |
| // different from the wrapped mlir::Attribute type to avoid redefining |
| // builders. This checks for the op has at least one such native attribute. |
| for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { |
| NamedAttribute &namedAttr = op.getAttribute(i); |
| if (canUseUnwrappedRawValue(namedAttr.attr)) { |
| canGenerate = true; |
| break; |
| } |
| } |
| return canGenerate; |
| } |
| |
| void OpEmitter::genSeparateArgParamBuilder() { |
| SmallVector<AttrParamKind, 2> attrBuilderType; |
| attrBuilderType.push_back(AttrParamKind::WrappedAttr); |
| if (canGenerateUnwrappedBuilder(op)) |
| attrBuilderType.push_back(AttrParamKind::UnwrappedValue); |
| |
| // Emit with separate builders with or without unwrapped attributes and/or |
| // inferring result type. |
| auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, |
| bool inferType) { |
| std::string paramList; |
| llvm::SmallVector<std::string, 4> resultNames; |
| buildParamList(paramList, resultNames, paramKind, attrType); |
| |
| auto &m = |
| opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); |
| auto &body = m.body(); |
| genCodeForAddingArgAndRegionForBuilder( |
| body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); |
| |
| // Push all result types to the operation state |
| |
| if (inferType) { |
| // Generate builder that infers type too. |
| // TODO(jpienaar): Subsume this with general checking if type can be |
| // infered automatically. |
| // TODO(jpienaar): Expand to handle regions. |
| body << formatv(R"( |
| SmallVector<Type, 2> inferedReturnTypes; |
| if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands, |
| {1}.attributes, /*regions=*/{{}, inferedReturnTypes))) |
| {1}.addTypes(inferedReturnTypes); |
| else |
| llvm::report_fatal_error("Failed to infer result type(s).");)", |
| opClass.getClassName(), builderOpState); |
| return; |
| } |
| |
| switch (paramKind) { |
| case TypeParamKind::None: |
| return; |
| case TypeParamKind::Separate: |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| body << " " << builderOpState << ".addTypes(" << resultNames[i] |
| << ");\n"; |
| } |
| return; |
| case TypeParamKind::Collective: |
| body << " " << builderOpState << ".addTypes(resultTypes);\n"; |
| return; |
| }; |
| llvm_unreachable("unhandled TypeParamKind"); |
| }; |
| |
| bool canInferType = |
| op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0; |
| for (auto attrType : attrBuilderType) { |
| emit(attrType, TypeParamKind::Separate, /*inferType=*/false); |
| if (canInferType) |
| emit(attrType, TypeParamKind::None, /*inferType=*/true); |
| // Emit separate arg build with collective type, unless there is only one |
| // variadic result, in which case the above would have already generated |
| // the same build method. |
| if (!(op.getNumResults() == 1 && op.getResult(0).isVariadic())) |
| emit(attrType, TypeParamKind::Collective, /*inferType=*/false); |
| } |
| } |
| |
| void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { |
| // If this op has a variadic result, we cannot generate this builder because |
| // we don't know how many results to create. |
| if (op.getNumVariadicResults() != 0) |
| return; |
| |
| int numResults = op.getNumResults(); |
| |
| // Signature |
| std::string params = |
| std::string("Builder *, OperationState &") + builderOpState + |
| ", ValueRange operands, ArrayRef<NamedAttribute> attributes"; |
| auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); |
| auto &body = m.body(); |
| |
| // Operands |
| body << " " << builderOpState << ".addOperands(operands);\n\n"; |
| |
| // Attributes |
| body << " " << builderOpState << ".addAttributes(attributes);\n"; |
| |
| // Create the correct number of regions |
| if (int numRegions = op.getNumRegions()) { |
| for (int i = 0; i < numRegions; ++i) |
| m.body() << " (void)" << builderOpState << ".addRegion();\n"; |
| } |
| |
| // Result types |
| SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()"); |
| body << " " << builderOpState << ".addTypes({" |
| << llvm::join(resultTypes, ", ") << "});\n\n"; |
| } |
| |
| void OpEmitter::genInferedTypeCollectiveParamBuilder() { |
| // TODO(jpienaar): Expand to support regions. |
| const char *params = |
| "Builder *builder, OperationState &{0}, " |
| "ValueRange operands, ArrayRef<NamedAttribute> attributes"; |
| auto &m = |
| opClass.newMethod("void", "build", formatv(params, builderOpState).str(), |
| OpMethod::MP_Static); |
| auto &body = m.body(); |
| body << formatv(R"( |
| SmallVector<Type, 2> inferedReturnTypes; |
| if (succeeded({0}::inferReturnTypes({1}.location, operands, attributes, |
| /*regions=*/{{}, inferedReturnTypes))) |
| build(builder, tblgen_state, inferedReturnTypes, operands, attributes); |
| else |
| llvm::report_fatal_error("Failed to infer result type(s).");)", |
| opClass.getClassName(), builderOpState); |
| } |
| |
| void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { |
| std::string paramList; |
| llvm::SmallVector<std::string, 4> resultNames; |
| buildParamList(paramList, resultNames, TypeParamKind::None); |
| |
| auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); |
| genCodeForAddingArgAndRegionForBuilder(m.body()); |
| |
| auto numResults = op.getNumResults(); |
| if (numResults == 0) |
| return; |
| |
| // Push all result types to the operation state |
| const char *index = op.getOperand(0).isVariadic() ? ".front()" : ""; |
| std::string resultType = |
| formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); |
| m.body() << " " << builderOpState << ".addTypes({" << resultType; |
| for (int i = 1; i != numResults; ++i) |
| m.body() << ", " << resultType; |
| m.body() << "});\n\n"; |
| } |
| |
| void OpEmitter::genUseAttrAsResultTypeBuilder() { |
| std::string params = |
| std::string("Builder *, OperationState &") + builderOpState + |
| ", ValueRange operands, ArrayRef<NamedAttribute> attributes"; |
| auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); |
| auto &body = m.body(); |
| |
| // Push all result types to the operation state |
| std::string resultType; |
| const auto &namedAttr = op.getAttribute(0); |
| |
| body << " for (auto attr : attributes) {\n"; |
| body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n"; |
| if (namedAttr.attr.isTypeAttr()) { |
| resultType = "attr.second.cast<TypeAttr>().getValue()"; |
| } else { |
| resultType = "attr.second.getType()"; |
| } |
| |
| // Operands |
| body << " " << builderOpState << ".addOperands(operands);\n\n"; |
| // Attributes |
| body << " " << builderOpState << ".addAttributes(attributes);\n"; |
| |
| // Result types |
| SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); |
| body << " " << builderOpState << ".addTypes({" |
| << llvm::join(resultTypes, ", ") << "});\n"; |
| body << " }\n"; |
| } |
| |
| void OpEmitter::genBuilder() { |
| // Handle custom builders if provided. |
| // TODO(antiagainst): Create wrapper class for OpBuilder to hide the native |
| // TableGen API calls here. |
| { |
| auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders")); |
| if (listInit) { |
| for (Init *init : listInit->getValues()) { |
| Record *builderDef = cast<DefInit>(init)->getDef(); |
| StringRef params = builderDef->getValueAsString("params"); |
| StringRef body = builderDef->getValueAsString("body"); |
| bool hasBody = !body.empty(); |
| |
| auto &method = |
| opClass.newMethod("void", "build", params, OpMethod::MP_Static, |
| /*declOnly=*/!hasBody); |
| if (hasBody) |
| method.body() << body; |
| } |
| } |
| if (op.skipDefaultBuilders()) { |
| if (!listInit || listInit->empty()) |
| PrintFatalError( |
| op.getLoc(), |
| "default builders are skipped and no custom builders provided"); |
| return; |
| } |
| } |
| |
| // Generate default builders that requires all result type, operands, and |
| // attributes as parameters. |
| |
| // We generate three classes of builders here: |
| // 1. one having a stand-alone parameter for each operand / attribute, and |
| genSeparateArgParamBuilder(); |
| // 2. one having an aggregated parameter for all result types / operands / |
| // attributes, and |
| genCollectiveParamBuilder(); |
| // 3. one having a stand-alone parameter for each operand and attribute, |
| // use the first operand or attribute's type as all result types |
| // to facilitate different call patterns. |
| if (op.getNumVariadicResults() == 0) { |
| if (op.getTrait("OpTrait::SameOperandsAndResultType")) { |
| genUseOperandAsResultTypeSeparateParamBuilder(); |
| genUseOperandAsResultTypeCollectiveParamBuilder(); |
| } |
| if (op.getTrait("OpTrait::FirstAttrDerivedResultType")) |
| genUseAttrAsResultTypeBuilder(); |
| } |
| } |
| |
| void OpEmitter::genCollectiveParamBuilder() { |
| int numResults = op.getNumResults(); |
| int numVariadicResults = op.getNumVariadicResults(); |
| int numNonVariadicResults = numResults - numVariadicResults; |
| |
| int numOperands = op.getNumOperands(); |
| int numVariadicOperands = op.getNumVariadicOperands(); |
| int numNonVariadicOperands = numOperands - numVariadicOperands; |
| // Signature |
| std::string params = std::string("Builder *, OperationState &") + |
| builderOpState + |
| ", ArrayRef<Type> resultTypes, ValueRange operands, " |
| "ArrayRef<NamedAttribute> attributes"; |
| auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); |
| auto &body = m.body(); |
| |
| // Operands |
| if (numVariadicOperands == 0 || numNonVariadicOperands != 0) |
| body << " assert(operands.size()" |
| << (numVariadicOperands != 0 ? " >= " : " == ") |
| << numNonVariadicOperands |
| << "u && \"mismatched number of parameters\");\n"; |
| body << " " << builderOpState << ".addOperands(operands);\n\n"; |
| |
| // Attributes |
| body << " " << builderOpState << ".addAttributes(attributes);\n"; |
| |
| // Create the correct number of regions |
| if (int numRegions = op.getNumRegions()) { |
| for (int i = 0; i < numRegions; ++i) |
| m.body() << " (void)" << builderOpState << ".addRegion();\n"; |
| } |
| |
| // Result types |
| if (numVariadicResults == 0 || numNonVariadicResults != 0) |
| body << " assert(resultTypes.size()" |
| << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults |
| << "u && \"mismatched number of return types\");\n"; |
| body << " " << builderOpState << ".addTypes(resultTypes);\n"; |
| |
| // Generate builder that infers type too. |
| // TODO(jpienaar): Subsume this with general checking if type can be infered |
| // automatically. |
| // TODO(jpienaar): Expand to handle regions. |
| if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) |
| genInferedTypeCollectiveParamBuilder(); |
| } |
| |
| void OpEmitter::buildParamList(std::string ¶mList, |
| SmallVectorImpl<std::string> &resultTypeNames, |
| TypeParamKind typeParamKind, |
| AttrParamKind attrParamKind) { |
| resultTypeNames.clear(); |
| auto numResults = op.getNumResults(); |
| resultTypeNames.reserve(numResults); |
| |
| paramList = "Builder *tblgen_builder, OperationState &"; |
| paramList.append(builderOpState); |
| |
| switch (typeParamKind) { |
| case TypeParamKind::None: |
| break; |
| case TypeParamKind::Separate: { |
| // Add parameters for all return types |
| for (int i = 0; i < numResults; ++i) { |
| const auto &result = op.getResult(i); |
| std::string resultName = result.name; |
| if (resultName.empty()) |
| resultName = formatv("resultType{0}", i); |
| |
| paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type "); |
| paramList.append(resultName); |
| |
| resultTypeNames.emplace_back(std::move(resultName)); |
| } |
| } break; |
| case TypeParamKind::Collective: { |
| paramList.append(", ArrayRef<Type> resultTypes"); |
| resultTypeNames.push_back("resultTypes"); |
| } break; |
| } |
| |
| // Add parameters for all arguments (operands and attributes). |
| |
| int numOperands = 0; |
| int numAttrs = 0; |
| |
| int defaultValuedAttrStartIndex = op.getNumArgs(); |
| if (attrParamKind == AttrParamKind::UnwrappedValue) { |
| // Calculate the start index from which we can attach default values in the |
| // builder declaration. |
| for (int i = op.getNumArgs() - 1; i >= 0; --i) { |
| auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>(); |
| if (!namedAttr || !namedAttr->attr.hasDefaultValue()) |
| break; |
| |
| if (!canUseUnwrappedRawValue(namedAttr->attr)) |
| break; |
| |
| // Creating an APInt requires us to provide bitwidth, value, and |
| // signedness, which is complicated compared to others. Similarly |
| // for APFloat. |
| // TODO(b/144412160) Adjust the 'returnType' field of such attributes |
| // to support them. |
| StringRef retType = namedAttr->attr.getReturnType(); |
| if (retType == "APInt" || retType == "APFloat") |
| break; |
| |
| defaultValuedAttrStartIndex = i; |
| } |
| } |
| |
| for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| auto argument = op.getArg(i); |
| if (argument.is<tblgen::NamedTypeConstraint *>()) { |
| const auto &operand = op.getOperand(numOperands); |
| paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value "); |
| paramList.append(getArgumentName(op, numOperands)); |
| ++numOperands; |
| } else { |
| const auto &namedAttr = op.getAttribute(numAttrs); |
| const auto &attr = namedAttr.attr; |
| paramList.append(", "); |
| |
| if (attr.isOptional()) |
| paramList.append("/*optional*/"); |
| |
| switch (attrParamKind) { |
| case AttrParamKind::WrappedAttr: |
| paramList.append(attr.getStorageType()); |
| break; |
| case AttrParamKind::UnwrappedValue: |
| if (canUseUnwrappedRawValue(attr)) { |
| paramList.append(attr.getReturnType()); |
| } else { |
| paramList.append(attr.getStorageType()); |
| } |
| break; |
| } |
| paramList.append(" "); |
| paramList.append(namedAttr.name); |
| |
| // Attach default value if requested and possible. |
| if (attrParamKind == AttrParamKind::UnwrappedValue && |
| i >= defaultValuedAttrStartIndex) { |
| bool isString = attr.getReturnType() == "StringRef"; |
| paramList.append(" = "); |
| if (isString) |
| paramList.append("\""); |
| paramList.append(attr.getDefaultValue()); |
| if (isString) |
| paramList.append("\""); |
| } |
| ++numAttrs; |
| } |
| } |
| } |
| |
| void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, |
| bool isRawValueAttr) { |
| // Push all operands to the result |
| for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
| body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i) |
| << ");\n"; |
| } |
| |
| // Push all attributes to the result |
| for (const auto &namedAttr : op.getAttributes()) { |
| auto &attr = namedAttr.attr; |
| if (!attr.isDerivedAttr()) { |
| bool emitNotNullCheck = attr.isOptional(); |
| if (emitNotNullCheck) { |
| body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; |
| } |
| if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { |
| // If this is a raw value, then we need to wrap it in an Attribute |
| // instance. |
| FmtContext fctx; |
| fctx.withBuilder("(*tblgen_builder)"); |
| std::string value = |
| tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name); |
| body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, |
| namedAttr.name, value); |
| } else { |
| body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, |
| namedAttr.name); |
| } |
| if (emitNotNullCheck) { |
| body << " }\n"; |
| } |
| } |
| } |
| |
| // Create the correct number of regions |
| if (int numRegions = op.getNumRegions()) { |
| for (int i = 0; i < numRegions; ++i) |
| body << " (void)" << builderOpState << ".addRegion();\n"; |
| } |
| } |
| |
| void OpEmitter::genCanonicalizerDecls() { |
| if (!def.getValueAsBit("hasCanonicalizer")) |
| return; |
| |
| const char *const params = |
| "OwningRewritePatternList &results, MLIRContext *context"; |
| opClass.newMethod("void", "getCanonicalizationPatterns", params, |
| OpMethod::MP_Static, /*declOnly=*/true); |
| } |
| |
| void OpEmitter::genFolderDecls() { |
| bool hasSingleResult = |
| op.getNumResults() == 1 && op.getNumVariadicResults() == 0; |
| |
| if (def.getValueAsBit("hasFolder")) { |
| if (hasSingleResult) { |
| const char *const params = "ArrayRef<Attribute> operands"; |
| opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None, |
| /*declOnly=*/true); |
| } else { |
| const char *const params = "ArrayRef<Attribute> operands, " |
| "SmallVectorImpl<OpFoldResult> &results"; |
| opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None, |
| /*declOnly=*/true); |
| } |
| } |
| } |
| |
| void OpEmitter::genOpInterfaceMethods() { |
| for (const auto &trait : op.getTraits()) { |
| auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait); |
| if (!opTrait || !opTrait->shouldDeclareMethods()) |
| continue; |
| auto interface = opTrait->getOpInterface(); |
| for (auto method : interface.getMethods()) { |
| // Don't declare if the method has a body. |
| if (method.getBody()) |
| continue; |
| std::string args; |
| llvm::raw_string_ostream os(args); |
| mlir::interleaveComma(method.getArguments(), os, |
| [&](const OpInterfaceMethod::Argument &arg) { |
| os << arg.type << " " << arg.name; |
| }); |
| opClass.newMethod(method.getReturnType(), method.getName(), os.str(), |
| method.isStatic() ? OpMethod::MP_Static |
| : OpMethod::MP_None, |
| /*declOnly=*/true); |
| } |
| } |
| } |
| |
| void OpEmitter::genParser() { |
| if (!hasStringAttribute(def, "parser")) |
| return; |
| |
| auto &method = opClass.newMethod( |
| "ParseResult", "parse", "OpAsmParser &parser, OperationState &result", |
| OpMethod::MP_Static); |
| FmtContext fctx; |
| fctx.addSubst("cppClass", opClass.getClassName()); |
| auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r"); |
| method.body() << " " << tgfmt(parser, &fctx); |
| } |
| |
| void OpEmitter::genPrinter() { |
| auto valueInit = def.getValueInit("printer"); |
| CodeInit *codeInit = dyn_cast<CodeInit>(valueInit); |
| if (!codeInit) |
| return; |
| |
| auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p"); |
| FmtContext fctx; |
| fctx.addSubst("cppClass", opClass.getClassName()); |
| auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r"); |
| method.body() << " " << tgfmt(printer, &fctx); |
| } |
| |
| void OpEmitter::genVerifier() { |
| auto valueInit = def.getValueInit("verifier"); |
| CodeInit *codeInit = dyn_cast<CodeInit>(valueInit); |
| bool hasCustomVerify = codeInit && !codeInit->getValue().empty(); |
| |
| auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/""); |
| auto &body = method.body(); |
| |
| // Populate substitutions for attributes and named operands and results. |
| for (const auto &namedAttr : op.getAttributes()) |
| verifyCtx.addSubst(namedAttr.name, |
| formatv("this->getAttr(\"{0}\")", namedAttr.name)); |
| for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
| auto &value = op.getOperand(i); |
| // Skip from from first variadic operands for now. Else getOperand index |
| // used below doesn't match. |
| if (value.isVariadic()) |
| break; |
| if (!value.name.empty()) |
| verifyCtx.addSubst(value.name, |
| formatv("this->getOperation()->getOperand({0})", i)); |
| } |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| auto &value = op.getResult(i); |
| // Skip from from first variadic results for now. Else getResult index used |
| // below doesn't match. |
| if (value.isVariadic()) |
| break; |
| if (!value.name.empty()) |
| verifyCtx.addSubst(value.name, |
| formatv("this->getOperation()->getResult({0})", i)); |
| } |
| |
| // Verify the attributes have the correct type. |
| for (const auto &namedAttr : op.getAttributes()) { |
| const auto &attr = namedAttr.attr; |
| if (attr.isDerivedAttr()) |
| continue; |
| |
| auto attrName = namedAttr.name; |
| // Prefix with `tblgen_` to avoid hiding the attribute accessor. |
| auto varName = tblgenNamePrefix + attrName; |
| body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName, |
| attrName); |
| |
| bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); |
| if (allowMissingAttr) { |
| // If the attribute has a default value, then only verify the predicate if |
| // set. This does effectively assume that the default value is valid. |
| // TODO: verify the debug value is valid (perhaps in debug mode only). |
| body << " if (" << varName << ") {\n"; |
| } else { |
| body << " if (!" << varName |
| << ") return emitOpError(\"requires attribute '" << attrName |
| << "'\");\n {\n"; |
| } |
| |
| auto attrPred = attr.getPredicate(); |
| if (!attrPred.isNull()) { |
| body << tgfmt( |
| " if (!($0)) return emitOpError(\"attribute '$1' " |
| "failed to satisfy constraint: $2\");\n", |
| /*ctx=*/nullptr, |
| tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)), |
| attrName, attr.getDescription()); |
| } |
| |
| body << " }\n"; |
| } |
| |
| const char *code = R"( |
| auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}"); |
| auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements(); |
| if (numElements != {1}) {{ |
| return emitOpError("'{0}' attribute for specifying {2} segments " |
| "must have {1} elements"); |
| } |
| )"; |
| |
| for (auto &trait : op.getTraits()) { |
| if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) { |
| body << tgfmt(" if (!($0)) {\n " |
| "return emitOpError(\"failed to verify that $1\");\n }\n", |
| &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), |
| t->getDescription()); |
| } else if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) { |
| if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") { |
| body << formatv(code, "operand_segment_sizes", op.getNumOperands(), |
| "operand"); |
| } else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") { |
| body << formatv(code, "result_segment_sizes", op.getNumResults(), |
| "result"); |
| } |
| } |
| } |
| |
| // These should happen after we verified the traits because |
| // getODSOperands()/getODSResults() may depend on traits (e.g., |
| // AttrSizedOperandSegments/AttrSizedResultSegments). |
| genOperandResultVerifier(body, op.getOperands(), "operand"); |
| genOperandResultVerifier(body, op.getResults(), "result"); |
| |
| genRegionVerifier(body); |
| |
| if (hasCustomVerify) { |
| FmtContext fctx; |
| fctx.addSubst("cppClass", opClass.getClassName()); |
| auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r"); |
| body << " " << tgfmt(printer, &fctx); |
| } else { |
| body << " return mlir::success();\n"; |
| } |
| } |
| |
| void OpEmitter::genOperandResultVerifier(OpMethodBody &body, |
| Operator::value_range values, |
| StringRef valueKind) { |
| FmtContext fctx; |
| |
| body << " {\n"; |
| body << " unsigned index = 0; (void)index;\n"; |
| |
| for (auto staticValue : llvm::enumerate(values)) { |
| if (!staticValue.value().hasPredicate()) |
| continue; |
| |
| // Emit a loop to check all the dynamic values in the pack. |
| body << formatv(" for (Value v : getODS{0}{1}s({2})) {{\n", |
| // Capitalize the first letter to match the function name |
| valueKind.substr(0, 1).upper(), valueKind.substr(1), |
| staticValue.index()); |
| |
| auto constraint = staticValue.value().constraint; |
| |
| body << " (void)v;\n" |
| << " if (!(" |
| << tgfmt(constraint.getConditionTemplate(), |
| &fctx.withSelf("v.getType()")) |
| << ")) {\n" |
| << formatv(" return emitOpError(\"{0} #\") << index " |
| "<< \" must be {1}, but got \" << v.getType();\n", |
| valueKind, constraint.getDescription()) |
| << " }\n" // if |
| << " ++index;\n" |
| << " }\n"; // for |
| } |
| |
| body << " }\n"; |
| } |
| |
| void OpEmitter::genRegionVerifier(OpMethodBody &body) { |
| unsigned numRegions = op.getNumRegions(); |
| |
| // Verify this op has the correct number of regions |
| body << formatv( |
| " if (this->getOperation()->getNumRegions() != {0}) {\n " |
| "return emitOpError(\"has incorrect number of regions: expected {0} but " |
| "found \") << this->getOperation()->getNumRegions();\n }\n", |
| numRegions); |
| |
| for (unsigned i = 0; i < numRegions; ++i) { |
| const auto ®ion = op.getRegion(i); |
| |
| std::string name = formatv("#{0}", i); |
| if (!region.name.empty()) { |
| name += formatv(" ('{0}')", region.name); |
| } |
| |
| auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str(); |
| auto constraint = tgfmt(region.constraint.getConditionTemplate(), |
| &verifyCtx.withSelf(getRegion)) |
| .str(); |
| |
| body << formatv(" if (!({0})) {\n " |
| "return emitOpError(\"region {1} failed to verify " |
| "constraint: {2}\");\n }\n", |
| constraint, name, region.constraint.getDescription()); |
| } |
| } |
| |
| void OpEmitter::genTraits() { |
| int numResults = op.getNumResults(); |
| int numVariadicResults = op.getNumVariadicResults(); |
| |
| // Add return size trait. |
| if (numVariadicResults != 0) { |
| if (numResults == numVariadicResults) |
| opClass.addTrait("OpTrait::VariadicResults"); |
| else |
| opClass.addTrait("OpTrait::AtLeastNResults<" + |
| Twine(numResults - numVariadicResults) + ">::Impl"); |
| } else { |
| switch (numResults) { |
| case 0: |
| opClass.addTrait("OpTrait::ZeroResult"); |
| break; |
| case 1: |
| opClass.addTrait("OpTrait::OneResult"); |
| break; |
| default: |
| opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl"); |
| break; |
| } |
| } |
| |
| for (const auto &trait : op.getTraits()) { |
| if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait)) |
| opClass.addTrait(opTrait->getTrait()); |
| else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait)) |
| opClass.addTrait(opTrait->getTrait()); |
| } |
| |
| // Add variadic size trait and normal op traits. |
| int numOperands = op.getNumOperands(); |
| int numVariadicOperands = op.getNumVariadicOperands(); |
| |
| // Add operand size trait. |
| if (numVariadicOperands != 0) { |
| if (numOperands == numVariadicOperands) |
| opClass.addTrait("OpTrait::VariadicOperands"); |
| else |
| opClass.addTrait("OpTrait::AtLeastNOperands<" + |
| Twine(numOperands - numVariadicOperands) + ">::Impl"); |
| } else { |
| switch (numOperands) { |
| case 0: |
| opClass.addTrait("OpTrait::ZeroOperands"); |
| break; |
| case 1: |
| opClass.addTrait("OpTrait::OneOperand"); |
| break; |
| default: |
| opClass.addTrait("OpTrait::NOperands<" + Twine(numOperands) + ">::Impl"); |
| break; |
| } |
| } |
| } |
| |
| void OpEmitter::genOpNameGetter() { |
| auto &method = opClass.newMethod("StringRef", "getOperationName", |
| /*params=*/"", OpMethod::MP_Static); |
| method.body() << " return \"" << op.getOperationName() << "\";\n"; |
| } |
| |
| void OpEmitter::genOpAsmInterface() { |
| // If the user only has one results or specifically added the Asm trait, |
| // then don't generate it for them. We specifically only handle multi result |
| // operations, because the name of a single result in the common case is not |
| // interesting(generally 'result'/'output'/etc.). |
| // TODO: We could also add a flag to allow operations to opt in to this |
| // generation, even if they only have a single operation. |
| int numResults = op.getNumResults(); |
| if (numResults <= 1 || op.getTrait("OpAsmOpInterface::Trait")) |
| return; |
| |
| SmallVector<StringRef, 4> resultNames(numResults); |
| for (int i = 0; i != numResults; ++i) |
| resultNames[i] = op.getResultName(i); |
| |
| // Don't add the trait if none of the results have a valid name. |
| if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); })) |
| return; |
| opClass.addTrait("OpAsmOpInterface::Trait"); |
| |
| // Generate the right accessor for the number of results. |
| auto &method = opClass.newMethod("void", "getAsmResultNames", |
| "OpAsmSetValueNameFn setNameFn"); |
| auto &body = method.body(); |
| for (int i = 0; i != numResults; ++i) { |
| body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" |
| << " if (!llvm::empty(resultGroup" << i << "))\n" |
| << " setNameFn(*resultGroup" << i << ".begin(), \"" |
| << resultNames[i] << "\");\n"; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpOperandAdaptor emitter |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Helper class to emit Op operand adaptors to an output stream. Operand |
| // adaptors are wrappers around ArrayRef<Value> that provide named operand |
| // getters identical to those defined in the Op. |
| class OpOperandAdaptorEmitter { |
| public: |
| static void emitDecl(const Operator &op, raw_ostream &os); |
| static void emitDef(const Operator &op, raw_ostream &os); |
| |
| private: |
| explicit OpOperandAdaptorEmitter(const Operator &op); |
| |
| Class adapterClass; |
| }; |
| } // end namespace |
| |
| OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) |
| : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { |
| adapterClass.newField("ArrayRef<Value>", "tblgen_operands"); |
| auto &constructor = adapterClass.newConstructor("ArrayRef<Value> values"); |
| constructor.body() << " tblgen_operands = values;\n"; |
| |
| generateNamedOperandGetters(op, adapterClass, |
| /*rangeType=*/"ArrayRef<Value>", |
| /*rangeBeginCall=*/"tblgen_operands.begin()", |
| /*rangeSizeCall=*/"tblgen_operands.size()", |
| /*getOperandCallPattern=*/"tblgen_operands[{0}]"); |
| } |
| |
| void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { |
| OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os); |
| } |
| |
| void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { |
| OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os); |
| } |
| |
| // Emits the opcode enum and op classes. |
| static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os, |
| bool emitDecl) { |
| IfDefScope scope("GET_OP_CLASSES", os); |
| // First emit forward declaration for each class, this allows them to refer |
| // to each others in traits for example. |
| if (emitDecl) { |
| for (auto *def : defs) { |
| Operator op(*def); |
| os << "class " << op.getCppClassName() << ";\n"; |
| } |
| } |
| for (auto *def : defs) { |
| Operator op(*def); |
| const auto *attrSizedOperands = |
| op.getTrait("OpTrait::AttrSizedOperandSegments"); |
| if (emitDecl) { |
| os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); |
| // We cannot generate the operand adaptor class if operand getters depend |
| // on an attribute. |
| if (!attrSizedOperands) |
| OpOperandAdaptorEmitter::emitDecl(op, os); |
| OpEmitter::emitDecl(op, os); |
| } else { |
| os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); |
| if (!attrSizedOperands) |
| OpOperandAdaptorEmitter::emitDef(op, os); |
| OpEmitter::emitDef(op, os); |
| } |
| } |
| } |
| |
| // Emits a comma-separated list of the ops. |
| static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) { |
| IfDefScope scope("GET_OP_LIST", os); |
| |
| interleave( |
| // TODO: We are constructing the Operator wrapper instance just for |
| // getting it's qualified class name here. Reduce the overhead by having a |
| // lightweight version of Operator class just for that purpose. |
| defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, |
| [&os]() { os << ",\n"; }); |
| } |
| |
| static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { |
| emitSourceFileHeader("Op Declarations", os); |
| |
| const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); |
| emitOpClasses(defs, os, /*emitDecl=*/true); |
| |
| return false; |
| } |
| |
| static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { |
| emitSourceFileHeader("Op Definitions", os); |
| |
| const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); |
| emitOpList(defs, os); |
| emitOpClasses(defs, os, /*emitDecl=*/false); |
| |
| return false; |
| } |
| |
| static mlir::GenRegistration |
| genOpDecls("gen-op-decls", "Generate op declarations", |
| [](const RecordKeeper &records, raw_ostream &os) { |
| return emitOpDecls(records, os); |
| }); |
| |
| static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", |
| [](const RecordKeeper &records, |
| raw_ostream &os) { |
| return emitOpDefs(records, os); |
| }); |