blob: 04f81a4a2dce1abc2d575dfcf149eb291359f676 [file] [log] [blame] [edit]
//===- OmpOpGen.cpp - OpenMP dialect op specific generators ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OmpOpGen defines OpenMP dialect operation specific generators.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace llvm;
/// The code block defining the base mixin class for combining clause operand
/// structures.
///
/// It is emitted verbatim by genClauseOps() and declares
/// `detail::Clauses<Mixins...>`, an empty variadic struct that publicly
/// inherits from every type in its template parameter pack.
static const char *const baseMixinClass = R"(
namespace detail {
template <typename... Mixins>
struct Clauses : public Mixins... {};
} // namespace detail
)";
/// The code block defining operation argument structures.
///
/// This is a `formatv` template filled in by genOperandsDef():
///  - `{0}`: the operation name with its "OpenMP_" prefix and "Op" suffix
///    stripped.
///  - `{1}`: the comma-separated list of clause operand structure names.
static const char *const operationArgStruct = R"(
using {0}Operands = detail::Clauses<{1}>;
)";
/// Strip a sequence of optional prefixes and suffixes from \c str.
///
/// Every entry of \c prefixes is tried exactly once, in the order given, and
/// removed from the front of the string if it matches at that point. Only once
/// all prefixes have been tried, every entry of \c suffixes is handled the same
/// way against the back of the string. Entries that do not match are skipped.
///
/// Since each entry is tried a single time, repeated occurrences are left in
/// place and the order of the entries matters:
///   - str: "PrePreNameSuf1Suf2"
///   - prefixes: ["Pre"]
///   - suffixes: ["Suf1", "Suf2"]
///   - return: "PreNameSuf1"
///
/// \param[in] str      The string to process.
/// \param[in] prefixes Candidate prefixes, in removal order.
/// \param[in] suffixes Candidate suffixes, in removal order.
///
/// \return a sub-range of \c str; no new storage is allocated.
static StringRef stripPrefixAndSuffix(StringRef str,
                                      llvm::ArrayRef<StringRef> prefixes,
                                      llvm::ArrayRef<StringRef> suffixes) {
  // consume_front/consume_back only modify the string when it matches, which
  // is exactly the "optional" semantics wanted here.
  for (StringRef candidate : prefixes)
    str.consume_front(candidate);

  for (StringRef candidate : suffixes)
    str.consume_back(candidate);

  return str;
}
/// Return the name of the OpenMP clause described by \c clause, a record that
/// derives from `OpenMP_Clause`.
///
/// Two cases are handled:
///   - `OpenMP_Clause` is a direct superclass of the record: the record's own
///     name is used.
///   - `OpenMP_Clause` is only reached indirectly: the name of the first
///     superclass that is itself a subclass of `OpenMP_Clause` is used.
///
/// The selected name then has its optional "OpenMP_" prefix and its optional
/// "Skip" and "Clause" suffixes removed, so that e.g.
/// "OpenMP_CollapseClauseSkip" is returned as "Collapse".
static StringRef extractOmpClauseName(const Record *clause) {
  const Record *baseClause = clause->getRecords().getClass("OpenMP_Clause");
  assert(baseClause && "base OpenMP records expected to be defined");

  SmallVector<const Record *, 1> directParents;
  clause->getDirectSuperClasses(directParents);

  // Direct inheritance: the record itself names the clause.
  StringRef className;
  if (llvm::is_contained(directParents, baseClause))
    className = clause->getName();

  // Indirect inheritance: look for the first superclass derived from
  // OpenMP_Clause and use its name instead.
  if (className.empty()) {
    for (auto [parent, _] : clause->getSuperClasses()) {
      if (!parent->isSubClassOf(baseClause))
        continue;
      className = parent->getName();
      break;
    }
  }
  assert(!className.empty() && "clause name must be found");

  // Only the bare clause name is of interest to callers.
  return stripPrefixAndSuffix(className, /*prefixes=*/{"OpenMP_"},
                              /*suffixes=*/{"Skip", "Clause"});
}
/// Check that the given argument, identified by its name and initialization
/// value, is present in the \c arguments `dag`.
static bool verifyArgument(const DagInit *arguments, StringRef argName,
const Init *argInit) {
auto range = zip_equal(arguments->getArgNames(), arguments->getArgs());
return llvm::any_of(
range, [&](std::tuple<const llvm::StringInit *, const llvm::Init *> v) {
return std::get<0>(v)->getAsUnquotedString() == argName &&
std::get<1>(v) == argInit;
});
}
/// Check that the given string record value, identified by its \c opValueName,
/// is either undefined or empty in both the given operation and clause record
/// or its contents for the clause record are contained in the operation record.
/// Passing a non-empty \c clauseValueName enables checking values named
/// differently in the operation and clause records.
static bool verifyStringValue(const Record *op, const Record *clause,
StringRef opValueName,
StringRef clauseValueName = {}) {
auto opValue = op->getValueAsOptionalString(opValueName);
auto clauseValue = clause->getValueAsOptionalString(
clauseValueName.empty() ? opValueName : clauseValueName);
bool opHasValue = opValue && !opValue->trim().empty();
bool clauseHasValue = clauseValue && !clauseValue->trim().empty();
if (!opHasValue)
return !clauseHasValue;
return !clauseHasValue || opValue->contains(clauseValue->trim());
}
/// Verify that all fields of the given clause not explicitly ignored are
/// present in the corresponding operation field.
///
/// Print warnings or errors where this is not the case.
static void verifyClause(const Record *op, const Record *clause) {
StringRef clauseClassName = extractOmpClauseName(clause);
if (!clause->getValueAsBit("ignoreArgs")) {
const DagInit *opArguments = op->getValueAsDag("arguments");
const DagInit *arguments = clause->getValueAsDag("arguments");
for (auto [name, arg] :
zip(arguments->getArgNames(), arguments->getArgs())) {
if (!verifyArgument(opArguments, name->getAsUnquotedString(), arg))
PrintWarning(
op->getLoc(),
"'" + clauseClassName + "' clause-defined argument '" +
arg->getAsUnquotedString() + ":$" +
name->getAsUnquotedString() +
"' not present in operation. Consider `dag arguments = "
"!con(clausesArgs, ...)` or explicitly skipping this field.");
}
}
if (!clause->getValueAsBit("ignoreAsmFormat") &&
!verifyStringValue(op, clause, "assemblyFormat", "reqAssemblyFormat"))
PrintWarning(
op->getLoc(),
"'" + clauseClassName +
"' clause-defined `reqAssemblyFormat` not present in operation. "
"Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or "
"explicitly skipping this field.");
if (!clause->getValueAsBit("ignoreAsmFormat") &&
!verifyStringValue(op, clause, "assemblyFormat", "optAssemblyFormat"))
PrintWarning(
op->getLoc(),
"'" + clauseClassName +
"' clause-defined `optAssemblyFormat` not present in operation. "
"Consider concatenating `clauses[{Req,Opt}]AssemblyFormat` or "
"explicitly skipping this field.");
if (!clause->getValueAsBit("ignoreDesc") &&
!verifyStringValue(op, clause, "description"))
PrintError(op->getLoc(),
"'" + clauseClassName +
"' clause-defined `description` not present in operation. "
"Consider concatenating `clausesDescription` or explicitly "
"skipping this field.");
if (!clause->getValueAsBit("ignoreExtraDecl") &&
!verifyStringValue(op, clause, "extraClassDeclaration"))
PrintWarning(
op->getLoc(),
"'" + clauseClassName +
"' clause-defined `extraClassDeclaration` not present in "
"operation. Consider concatenating `clausesExtraClassDeclaration` "
"or explicitly skipping this field.");
}
/// Map the type of an OpenMP clause argument to the type used to represent it
/// in a clause operand structure.
///
/// The decision is based on the names of the superclasses of the argument's
/// definition:
///   - `OptionalAttr` is looked through via its `baseAttr` field.
///   - `TypedArrayAttrBase` adds one nesting level and recurses into its
///     `elementAttr` field.
///   - `ElementsAttrBase` adds one nesting level when the element type can be
///     inferred (`::llvm::APInt`, `::llvm::APFloat` or the element type of the
///     `::llvm::ArrayRef<...>` in `returnType`). Otherwise a warning is printed
///     and the argument is handled as a simple attribute.
///   - `Variadic` adds one nesting level.
///   - `TypeConstraint` arguments are values, represented as `::mlir::Value`.
///   - Any other argument must be an `Attr`. It is represented by its trimmed
///     `storageType`, or by `::mlir::Attribute` if it is nested.
///
/// \param[in] loc The location used for any warning that is printed.
/// \param[in] name The name of the argument.
/// \param[in] init The `DefInit` object representing the argument.
/// \param[out] nest Number of levels of array nesting associated with the
///                  type. Must be initially set to 0.
/// \param[out] rank Rank (number of dimensions, if an array type) of the base
///                  type. Must be initially set to 1. It is currently left
///                  untouched.
///
/// \return the name of the base type to represent elements of the argument
///         type.
static StringRef translateArgumentType(ArrayRef<SMLoc> loc,
                                       const StringInit *name, const Init *init,
                                       int &nest, int &rank) {
  const Record *def = cast<DefInit>(init)->getDef();

  // Collect the names of all superclasses once, to query them below.
  llvm::StringSet<> ancestry;
  for (auto [parent, _] : def->getSuperClasses())
    ancestry.insert(parent->getNameInitAsString());
  auto derivesFrom = [&ancestry](StringRef className) {
    return ancestry.contains(className);
  };

  // Wrapper-style superclasses: translate the wrapped attribute instead.
  if (derivesFrom("OptionalAttr")) {
    const Init *wrapped = def->getValue("baseAttr")->getValue();
    return translateArgumentType(loc, name, wrapped, nest, rank);
  }
  if (derivesFrom("TypedArrayAttrBase")) {
    const Init *element = def->getValue("elementAttr")->getValue();
    ++nest;
    return translateArgumentType(loc, name, element, nest, rank);
  }

  // ElementsAttrBase superclasses: one extra nesting level is only added when
  // the element type is known, so that the bare storageType is used instead of
  // a vector in the remaining cases.
  if (derivesFrom("ElementsAttrBase")) {
    // TODO: Obtain the rank from ranked types.
    if (derivesFrom("IntElementsAttrBase")) {
      ++nest;
      return "::llvm::APInt";
    }
    if (derivesFrom("FloatElementsAttr") ||
        derivesFrom("RankedFloatElementsAttr")) {
      ++nest;
      return "::llvm::APFloat";
    }
    if (derivesFrom("DenseArrayAttrBase")) {
      ++nest;
      return stripPrefixAndSuffix(def->getValueAsString("returnType"),
                                  {"::llvm::ArrayRef<"}, {">"});
    }

    PrintWarning(
        loc,
        "could not infer array-like attribute element type for argument '" +
            name->getAsUnquotedString() + "', will use bare `storageType`");
  }

  // Simple attribute and value types.
  if (derivesFrom("Variadic"))
    ++nest;

  [[maybe_unused]] const bool isAttr = derivesFrom("Attr");
  if (derivesFrom("TypeConstraint")) {
    assert(!isAttr &&
           "argument can't be simultaneously a value and an attribute");
    return "::mlir::Value";
  }

  assert(isAttr && "argument must be an attribute if it's not a value");
  if (nest > 0)
    return "::mlir::Attribute";
  return def->getValueAsString("storageType").trim();
}
/// Generate the structure that represents the arguments of the given \c clause
/// record of type \c OpenMP_Clause.
///
/// It will contain a field for each argument, using the same name translated to
/// camel case and the corresponding base type as returned by
/// translateArgumentType() optionally wrapped in one or more llvm::SmallVector.
///
/// An additional field containing a tuple of integers to hold the size of each
/// dimension will also be created for multi-rank types. This is not yet
/// supported.
static void genClauseOpsStruct(const Record *clause, raw_ostream &os) {
if (clause->isAnonymous())
return;
StringRef clauseName = extractOmpClauseName(clause);
os << "struct " << clauseName << "ClauseOps {\n";
const DagInit *arguments = clause->getValueAsDag("arguments");
for (auto [name, arg] :
zip_equal(arguments->getArgNames(), arguments->getArgs())) {
int nest = 0, rank = 1;
StringRef baseType =
translateArgumentType(clause->getLoc(), name, arg, nest, rank);
std::string fieldName =
convertToCamelFromSnakeCase(name->getAsUnquotedString(),
/*capitalizeFirst=*/false);
os << formatv(" {0}{1}{2} {3};\n",
fmt_repeat("::llvm::SmallVector<", nest), baseType,
fmt_repeat(">", nest), fieldName);
if (rank > 1) {
assert(nest >= 1 && "must be nested if it's a ranked type");
os << formatv(" {0}::std::tuple<{1}int>{2} {3}Dims;\n",
fmt_repeat("::llvm::SmallVector<", nest - 1),
fmt_repeat("int, ", rank - 1), fmt_repeat(">", nest - 1),
fieldName);
}
}
os << "};\n";
}
/// Generate the structure that represents the clause-related arguments of the
/// given \c op record of type \c OpenMP_Op.
///
/// This structure will be defined in terms of the clause operand structures
/// associated to the clauses of the operation.
static void genOperandsDef(const Record *op, raw_ostream &os) {
if (op->isAnonymous())
return;
SmallVector<std::string> clauseNames;
for (const Record *clause : op->getValueAsListOfDefs("clauseList"))
clauseNames.push_back((extractOmpClauseName(clause) + "ClauseOps").str());
StringRef opName = stripPrefixAndSuffix(
op->getName(), /*prefixes=*/{"OpenMP_"}, /*suffixes=*/{"Op"});
os << formatv(operationArgStruct, opName, join(clauseNames, ", "));
}
/// Entry point of the `verify-openmp-ops` generator.
///
/// Run verifyClause() on every (operation, clause) pair, where the operation is
/// a definition deriving from `OpenMP_Op` and the clause is an entry of its
/// `clauseList`. Mismatches are reported as diagnostics by verifyClause();
/// nothing is written to the output stream.
///
/// \param[in] records The records to verify.
///
/// \return always false.
static bool verifyDecls(const RecordKeeper &records, raw_ostream &) {
  for (const Record *opDef : records.getAllDerivedDefinitions("OpenMP_Op"))
    for (const Record *clauseDef : opDef->getValueAsListOfDefs("clauseList"))
      verifyClause(opDef, clauseDef);

  return false;
}
/// Entry point of the `gen-openmp-clause-ops` generator.
///
/// Inside the `mlir::omp` namespace, emit in this order:
///   1. One clause operand structure per definition deriving from
///      `OpenMP_Clause`.
///   2. The `detail::Clauses` mixin class.
///   3. One operands alias per definition deriving from `OpenMP_Op`,
///      aggregating the structures of the clauses in its `clauseList`.
///
/// \param[in] records The records to generate code for.
/// \param[in,out] os  The stream the generated code is written to.
///
/// \return always false.
static bool genClauseOps(const RecordKeeper &records, raw_ostream &os) {
  // Keeps the namespace open for as long as this object is alive.
  mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp");

  // Per-clause structures.
  for (const Record *clauseDef :
       records.getAllDerivedDefinitions("OpenMP_Clause"))
    genClauseOpsStruct(clauseDef, os);

  // The mixin class must be emitted before the aliases that instantiate it.
  os << baseMixinClass;

  // Per-operation aliases.
  for (const Record *opDef : records.getAllDerivedDefinitions("OpenMP_Op"))
    genOperandsDef(opDef, os);

  return false;
}
// Registers the generators to mlir-tblgen. Each registration associates a
// generator name and a description with the function implementing it.

// "verify-openmp-ops": implemented by verifyDecls().
static mlir::GenRegistration
verifyOpenmpOps("verify-openmp-ops",
"Verify OpenMP operations (produce no output file)",
verifyDecls);
// "gen-openmp-clause-ops": implemented by genClauseOps().
static mlir::GenRegistration
genOpenmpClauseOps("gen-openmp-clause-ops",
"Generate OpenMP clause operand structures",
genClauseOps);