blob: 48634a17ae32aee4e094b36764fd9b5a6135fc16 [file] [log] [blame] [edit]
//===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate IRDL
// definitions for ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/GenNameParser.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using tblgen::NamedTypeConstraint;
// Command-line options of the `-gen-dialect-irdl-defs` backend registered at
// the bottom of this file. The category description names that flag (it
// previously referred to a non-existent `-gen-irdl-dialect` flag).
static llvm::cl::OptionCategory
    dialectGenCat("Options for -gen-dialect-irdl-defs");
// Name of the dialect to emit; TypeDefs, AttrDefs and Ops whose dialect name
// differs from it are skipped, and it becomes the name of the emitted
// `irdl.dialect`.
static llvm::cl::opt<std::string>
    selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                    llvm::cl::cat(dialectGenCat), llvm::cl::Required);
/// Translates a TableGen predicate into an IRDL constraint value.
///
/// `And`/`Or` combined predicates are translated structurally: each child is
/// translated recursively (in order) and the results are combined with an
/// `irdl.all_of` or `irdl.any_of`. Every other predicate — leaves and any
/// other combiner kind — becomes an opaque `irdl.c_pred` holding the
/// predicate's C++ condition string.
static Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
  MLIRContext *ctx = builder.getContext();
  Location loc = UnknownLoc::get(ctx);

  if (pred.isCombined()) {
    const Record &def = pred.getDef();
    StringRef kind = def.getValueAsDef("kind")->getName();
    bool isAnd = kind == "PredCombinerAnd";
    bool isOr = kind == "PredCombinerOr";
    if (isAnd || isOr) {
      std::vector<Value> children;
      for (const Record *childRec : def.getValueAsListOfDefs("children"))
        children.push_back(createPredicate(builder, tblgen::Pred(childRec)));
      if (isAnd)
        return irdl::AllOfOp::create(builder, loc, children).getOutput();
      return irdl::AnyOfOp::create(builder, loc, children).getOutput();
    }
  }

  // Fall back to the C++ condition the predicate expands to.
  std::string condition = pred.getCondition();
  irdl::CPredOp cpred =
      irdl::CPredOp::create(builder, loc, StringAttr::get(ctx, condition));
  return cpred;
}
/// Emits an `irdl.is` constraint matching exactly `type`.
static Value typeToConstraint(OpBuilder &builder, Type type) {
  Location loc = UnknownLoc::get(builder.getContext());
  return irdl::IsOp::create(builder, loc, TypeAttr::get(type)).getOutput();
}
/// Emits an `irdl.base` constraint referring to `baseClass` by name
/// (e.g. `"!builtin.integer"`).
static Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
  MLIRContext *ctx = builder.getContext();
  StringAttr baseName = StringAttr::get(ctx, baseClass);
  return irdl::BaseOp::create(builder, UnknownLoc::get(ctx), baseName)
      .getOutput();
}
/// Maps a TableGen type record onto the concrete builtin type it denotes.
///
/// Recognized records:
/// - fixed-width integers of class `I`, `SI` or `UI` (signless, signed,
///   unsigned respectively), using the `bitwidth` field;
/// - `Index`;
/// - `F<w>` floats for `w` in {16, 32, 64, 80, 128};
/// - the named records `NoneType`, `BF16`, `TF32` and the `F8*` formats;
/// - `Complex<...>` whose `elementType` is itself recognized.
/// Returns `std::nullopt` for anything else.
static std::optional<Type> recordToType(MLIRContext *ctx,
                                        const Record &predRec) {
  // The integer class selects the signedness; checked in this order.
  const std::pair<StringRef, IntegerType::SignednessSemantics> intKinds[] = {
      {"I", IntegerType::Signless},
      {"SI", IntegerType::Signed},
      {"UI", IntegerType::Unsigned},
  };
  for (const auto &[className, signedness] : intKinds) {
    if (predRec.isSubClassOf(className))
      return IntegerType::get(ctx, predRec.getValueAsInt("bitwidth"),
                              signedness);
  }

  StringRef name = predRec.getName();
  if (name == "Index")
    return IndexType::get(ctx);

  // Floats parameterized only by their width; other widths fall through.
  if (predRec.isSubClassOf("F")) {
    switch (predRec.getValueAsInt("bitwidth")) {
    case 16:
      return Float16Type::get(ctx);
    case 32:
      return Float32Type::get(ctx);
    case 64:
      return Float64Type::get(ctx);
    case 80:
      return Float80Type::get(ctx);
    case 128:
      return Float128Type::get(ctx);
    default:
      break;
    }
  }

  // Types identified purely by the record's name.
  if (name == "NoneType")
    return NoneType::get(ctx);
  if (name == "BF16")
    return BFloat16Type::get(ctx);
  if (name == "TF32")
    return FloatTF32Type::get(ctx);
  if (name == "F8E4M3FN")
    return Float8E4M3FNType::get(ctx);
  if (name == "F8E5M2")
    return Float8E5M2Type::get(ctx);
  if (name == "F8E4M3")
    return Float8E4M3Type::get(ctx);
  if (name == "F8E4M3FNUZ")
    return Float8E4M3FNUZType::get(ctx);
  if (name == "F8E4M3B11FNUZ")
    return Float8E4M3B11FNUZType::get(ctx);
  if (name == "F8E5M2FNUZ")
    return Float8E5M2FNUZType::get(ctx);
  if (name == "F8E3M4")
    return Float8E3M4Type::get(ctx);

  // Complex types are only mappable when their element type is.
  if (predRec.isSubClassOf("Complex")) {
    const Record *elementRec = predRec.getValueAsDef("elementType");
    if (std::optional<Type> elementType = recordToType(ctx, *elementRec))
      return ComplexType::get(*elementType);
  }

  return std::nullopt;
}
/// Translates an ODS type constraint into an IRDL constraint value.
///
/// `Variadic`/`Optional` wrappers are unwrapped to their `baseType`. Then, in
/// order: `AnyType` -> `irdl.any`; `TypeDef` -> `irdl.base` (a symbol
/// reference `@dialect::@!mnemonic` for the selected dialect, the string
/// `!typeName` otherwise); `AnyTypeOf`/`AllOfType` -> `irdl.any_of`/
/// `irdl.all_of` over `allowedTypes`; `AnyInteger` -> base
/// `!builtin.integer`; `AnyI<w>` -> any of the signless/signed/unsigned
/// integer of that width; types known to `recordToType` -> `irdl.is`;
/// `ConfinedType` -> `irdl.all_of` of its base and `predicateList`.
/// Anything else is lowered through the constraint's predicate.
static Value createTypeConstraint(OpBuilder &builder,
                                  tblgen::Constraint constraint) {
  MLIRContext *ctx = builder.getContext();
  const Record &predRec = constraint.getDef();

  // Recursively translates a type record referenced by this constraint.
  auto translate = [&](const Record *rec) {
    return createTypeConstraint(builder, tblgen::Constraint(rec));
  };
  // Translates each record of the list-valued field `field`, in order.
  auto translateList = [&](StringRef field) {
    std::vector<Value> values;
    for (const Record *rec : predRec.getValueAsListOfDefs(field))
      values.push_back(translate(rec));
    return values;
  };

  if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
    return translate(predRec.getValueAsDef("baseType"));

  if (predRec.getName() == "AnyType")
    return irdl::AnyOp::create(builder, UnknownLoc::get(ctx)).getOutput();

  if (predRec.isSubClassOf("TypeDef")) {
    StringRef dialect =
        predRec.getValueAsDef("dialect")->getValueAsString("name");
    if (dialect == selectedDialect) {
      // Types defined by the emitted dialect are referenced symbolically.
      std::string symName = ("!" + predRec.getValueAsString("mnemonic")).str();
      SmallVector<FlatSymbolRefAttr> nested = {
          SymbolRefAttr::get(ctx, symName)};
      auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
      return irdl::BaseOp::create(builder, UnknownLoc::get(ctx), typeSymbol)
          .getOutput();
    }
    // Foreign types are referenced by their full `!dialect.name` string.
    std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
    return irdl::BaseOp::create(builder, UnknownLoc::get(ctx),
                                StringAttr::get(ctx, typeName))
        .getOutput();
  }

  if (predRec.isSubClassOf("AnyTypeOf")) {
    std::vector<Value> alternatives = translateList("allowedTypes");
    return irdl::AnyOfOp::create(builder, UnknownLoc::get(ctx), alternatives)
        .getOutput();
  }

  if (predRec.isSubClassOf("AllOfType")) {
    std::vector<Value> requirements = translateList("allowedTypes");
    return irdl::AllOfOp::create(builder, UnknownLoc::get(ctx), requirements)
        .getOutput();
  }

  // Integer types.
  if (predRec.getName() == "AnyInteger")
    return baseToConstraint(builder, "!builtin.integer");

  if (predRec.isSubClassOf("AnyI")) {
    auto width = predRec.getValueAsInt("bitwidth");
    std::vector<Value> variants;
    for (IntegerType::SignednessSemantics signedness :
         {IntegerType::Signless, IntegerType::Signed, IntegerType::Unsigned})
      variants.push_back(typeToConstraint(
          builder, IntegerType::get(ctx, width, signedness)));
    return irdl::AnyOfOp::create(builder, UnknownLoc::get(ctx), variants)
        .getOutput();
  }

  // Concrete builtin types are matched exactly.
  if (std::optional<Type> type = recordToType(ctx, predRec))
    return typeToConstraint(builder, *type);

  // Confined type: the base type plus every extra predicate.
  if (predRec.isSubClassOf("ConfinedType")) {
    std::vector<Value> parts;
    parts.push_back(translate(predRec.getValueAsDef("baseType")));
    for (const Record *predicate :
         predRec.getValueAsListOfDefs("predicateList"))
      parts.push_back(createPredicate(builder, tblgen::Pred(predicate)));
    return irdl::AllOfOp::create(builder, UnknownLoc::get(ctx), parts)
        .getOutput();
  }

  return createPredicate(builder, constraint.getPredicate());
}
/// Translates an ODS attribute constraint into an IRDL constraint value.
///
/// `DefaultValuedAttr`, `DefaultValuedOptionalAttr` and `OptionalAttr` are
/// unwrapped to their `baseAttr`. `ConfinedAttr` becomes an `irdl.all_of` of
/// its base constraint and the `predicate` of each of its `attrConstraints`.
/// `AnyAttrOf` becomes an `irdl.any_of`, `AnyAttr` an `irdl.any`, and
/// `UnitAttr` an `irdl.is` on the unit attribute. Builtin integer (including
/// bool), float and string attribute families are matched by their
/// `#builtin.*` base name. `AttrDef`s are referenced symbolically when they
/// belong to the selected dialect and by their `#attrName` string otherwise.
/// Anything else is lowered through the constraint's predicate.
static Value createAttrConstraint(OpBuilder &builder,
                                  tblgen::Constraint constraint) {
  MLIRContext *ctx = builder.getContext();
  const Record &predRec = constraint.getDef();

  // Wrappers that only add defaulting/optionality constrain the wrapped attr.
  if (predRec.isSubClassOf("DefaultValuedAttr") ||
      predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
      predRec.isSubClassOf("OptionalAttr")) {
    return createAttrConstraint(
        builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr")));
  }

  if (predRec.isSubClassOf("ConfinedAttr")) {
    std::vector<Value> constraints;
    constraints.push_back(createAttrConstraint(
        builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
    // Each additional attribute constraint contributes its `predicate`.
    for (const Record *child :
         predRec.getValueAsListOfDefs("attrConstraints")) {
      constraints.push_back(createPredicate(
          builder, tblgen::Pred(child->getValueAsDef("predicate"))));
    }
    auto op = irdl::AllOfOp::create(builder, UnknownLoc::get(ctx), constraints);
    return op.getOutput();
  }

  if (predRec.isSubClassOf("AnyAttrOf")) {
    std::vector<Value> constraints;
    for (const Record *child :
         predRec.getValueAsListOfDefs("allowedAttributes")) {
      constraints.push_back(
          createAttrConstraint(builder, tblgen::Constraint(child)));
    }
    auto op = irdl::AnyOfOp::create(builder, UnknownLoc::get(ctx), constraints);
    return op.getOutput();
  }

  if (predRec.getName() == "AnyAttr") {
    auto op = irdl::AnyOp::create(builder, UnknownLoc::get(ctx));
    return op.getOutput();
  }

  // Builtin attribute families are matched by base name. Attribute bases use
  // the `#` sigil (as the `AttrDef` case below does); `!` denotes a *type*
  // base (see `createTypeConstraint`), which these attribute values are not.
  //
  // `BoolAttr` is declared with `def` (not `class`) in ODS, so
  // `isSubClassOf("BoolAttr")` alone never matches the record itself; it is
  // therefore also matched by name.
  if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
      predRec.isSubClassOf("SignlessIntegerAttrBase") ||
      predRec.isSubClassOf("SignedIntegerAttrBase") ||
      predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
      predRec.isSubClassOf("BoolAttr") || predRec.getName() == "BoolAttr") {
    return baseToConstraint(builder, "#builtin.integer");
  }

  if (predRec.isSubClassOf("FloatAttrBase")) {
    return baseToConstraint(builder, "#builtin.float");
  }

  if (predRec.isSubClassOf("StringBasedAttr")) {
    return baseToConstraint(builder, "#builtin.string");
  }

  if (predRec.getName() == "UnitAttr") {
    auto op =
        irdl::IsOp::create(builder, UnknownLoc::get(ctx), UnitAttr::get(ctx));
    return op.getOutput();
  }

  if (predRec.isSubClassOf("AttrDef")) {
    auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
    if (dialect == selectedDialect) {
      // Attributes of the emitted dialect are referenced symbolically.
      std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
      SmallVector<FlatSymbolRefAttr> nested = {
          SymbolRefAttr::get(ctx, combined)};
      auto attrSymbol = SymbolRefAttr::get(ctx, dialect, nested);
      auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx), attrSymbol);
      return op.getOutput();
    }
    // Foreign attributes are referenced by their `#attrName` string.
    std::string attrName = ("#" + predRec.getValueAsString("attrName")).str();
    auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx),
                                   StringAttr::get(ctx, attrName));
    return op.getOutput();
  }

  return createPredicate(builder, constraint.getPredicate());
}
/// Translates an ODS region constraint into an IRDL `irdl.region` value.
///
/// `AnyRegion` becomes a region op with no entry-block argument constraints.
/// `SizedRegion<...>` additionally carries the record's `blocks` value as an
/// i32 integer attribute. Any other region constraint is lowered through its
/// predicate.
static Value createRegionConstraint(OpBuilder &builder,
                                    tblgen::Region constraint) {
  MLIRContext *ctx = builder.getContext();
  const Record &regionRec = constraint.getDef();

  bool isAnyRegion = regionRec.getName() == "AnyRegion";
  if (!isAnyRegion && !regionRec.isSubClassOf("SizedRegion"))
    return createPredicate(builder, constraint.getPredicate());

  // Neither supported form constrains the entry block's arguments.
  ValueRange entryBlockArgs = {};
  if (isAnyRegion)
    return irdl::RegionOp::create(builder, UnknownLoc::get(ctx),
                                  entryBlockArgs)
        .getResult();

  IntegerAttr numBlocks = IntegerAttr::get(IntegerType::get(ctx, 32),
                                           regionRec.getValueAsInt("blocks"));
  return irdl::RegionOp::create(builder, UnknownLoc::get(ctx), entryBlockArgs,
                                numBlocks)
      .getResult();
}
/// Returns the operation's name without the dialect prefix, read from the
/// op record's `opName` field.
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
  return tblgenOp.getDef().getValueAsString("opName");
}
/// Returns the type's name without the dialect prefix, read from the
/// TypeDef record's `mnemonic` field.
static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
  return tblgenType.getDef()->getValueAsString("mnemonic");
}
/// Returns the attribute's name without the dialect prefix, read from the
/// AttrDef record's `mnemonic` field.
static StringRef getAttrName(tblgen::AttrDef &tblgenAttr) {
  return tblgenAttr.getDef()->getValueAsString("mnemonic");
}
/// Extracts an ODS operation into an `irdl.operation` created at the
/// builder's insertion point.
///
/// The operation's body receives, in order:
/// - the type constraints of all operands, then of all results;
/// - the constraints of every non-optional attribute (optional attributes are
///   not emitted);
/// - the region constraints;
/// - the `irdl.operands`, `irdl.results`, `irdl.attributes` and
///   `irdl.regions` ops binding those constraints to names. Each of these is
///   emitted only when its list is non-empty.
///
/// Operands, results and regions with an empty name receive a generated
/// `unnamed<N>` name. Generated names avoid every declared operand, result
/// and region name.
static irdl::OperationOp createIRDLOperation(OpBuilder &builder,
                                             tblgen::Operator &tblgenOp) {
  MLIRContext *ctx = builder.getContext();
  StringRef opName = getOperatorName(tblgenOp);

  irdl::OperationOp op = irdl::OperationOp::create(
      builder, UnknownLoc::get(ctx), StringAttr::get(ctx, opName));

  // Add the block in the region; all constraints are emitted into it.
  Block &opBlock = op.getBody().emplaceBlock();
  OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock);

  // Declared names that generated names must not collide with.
  SmallDenseSet<StringRef> usedNames;
  for (const NamedTypeConstraint &operand : tblgenOp.getOperands())
    usedNames.insert(operand.name);
  for (const NamedTypeConstraint &result : tblgenOp.getResults())
    usedNames.insert(result.name);
  for (const auto &region : tblgenOp.getRegions())
    usedNames.insert(region.name);

  // A single counter shared by every generated name keeps them distinct.
  size_t generateCounter = 0;
  auto generateName = [&](StringRef prefix) -> StringAttr {
    SmallString<16> candidate;
    do {
      candidate.clear();
      raw_svector_ostream candidateStream(candidate);
      candidateStream << prefix << generateCounter;
      generateCounter++;
    } while (usedNames.contains(candidate));
    return StringAttr::get(ctx, candidate);
  };

  // Empty names are replaced by a freshly generated one.
  auto normalizeName = [&](StringRef name) -> StringAttr {
    if (name.empty())
      return generateName("unnamed");
    return StringAttr::get(ctx, name);
  };

  // Translates a range of operands or results into parallel lists of
  // constraint values, names and variadicities, in declaration order.
  auto getValues = [&](tblgen::Operator::const_value_range range) {
    SmallVector<Value> constraints;
    SmallVector<Attribute> names;
    SmallVector<irdl::VariadicityAttr> variadicity;
    for (const NamedTypeConstraint &namedCons : range) {
      constraints.push_back(
          createTypeConstraint(consBuilder, namedCons.constraint));
      names.push_back(normalizeName(namedCons.name));
      irdl::Variadicity kind = irdl::Variadicity::single;
      if (namedCons.isOptional())
        kind = irdl::Variadicity::optional;
      else if (namedCons.isVariadic())
        kind = irdl::Variadicity::variadic;
      variadicity.push_back(consBuilder.getAttr<irdl::VariadicityAttr>(kind));
    }
    return std::make_tuple(std::move(constraints), std::move(names),
                           std::move(variadicity));
  };

  auto [operands, operandNames, operandVariadicity] =
      getValues(tblgenOp.getOperands());
  auto [results, resultNames, resultVariadicity] =
      getValues(tblgenOp.getResults());

  SmallVector<Value> attributes;
  SmallVector<Attribute> attrNames;
  for (const auto &namedAttr : tblgenOp.getAttributes()) {
    // Optional attributes are not described.
    if (namedAttr.attr.isOptional())
      continue;
    attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
    attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
  }

  SmallVector<Value> regions;
  SmallVector<Attribute> regionNames;
  for (const auto &namedRegion : tblgenOp.getRegions()) {
    regions.push_back(
        createRegionConstraint(consBuilder, namedRegion.constraint));
    regionNames.push_back(normalizeName(namedRegion.name));
  }

  // Create the operands, results, attributes and regions operations.
  if (!operands.empty())
    irdl::OperandsOp::create(consBuilder, UnknownLoc::get(ctx), operands,
                             ArrayAttr::get(ctx, operandNames),
                             operandVariadicity);
  if (!results.empty())
    irdl::ResultsOp::create(consBuilder, UnknownLoc::get(ctx), results,
                            ArrayAttr::get(ctx, resultNames),
                            resultVariadicity);
  if (!attributes.empty())
    irdl::AttributesOp::create(consBuilder, UnknownLoc::get(ctx), attributes,
                               ArrayAttr::get(ctx, attrNames));
  if (!regions.empty())
    irdl::RegionsOp::create(consBuilder, UnknownLoc::get(ctx), regions,
                            ArrayAttr::get(ctx, regionNames));
  return op;
}
/// Creates an `irdl.type` named `!<mnemonic>` for `tblgenType` at the
/// builder's insertion point. Its body is a single empty block.
static irdl::TypeOp createIRDLType(OpBuilder &builder,
                                   tblgen::TypeDef &tblgenType) {
  MLIRContext *ctx = builder.getContext();
  StringAttr symName =
      StringAttr::get(ctx, ("!" + getTypeName(tblgenType)).str());
  auto typeOp = irdl::TypeOp::create(builder, UnknownLoc::get(ctx), symName);
  typeOp.getBody().emplaceBlock();
  return typeOp;
}
/// Creates an `irdl.attribute` named `#<mnemonic>` for `tblgenAttr` at the
/// builder's insertion point. Its body is a single empty block.
static irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
                                        tblgen::AttrDef &tblgenAttr) {
  MLIRContext *ctx = builder.getContext();
  StringAttr symName =
      StringAttr::get(ctx, ("#" + getAttrName(tblgenAttr)).str());
  auto attrOp =
      irdl::AttributeOp::create(builder, UnknownLoc::get(ctx), symName);
  attrOp.getBody().emplaceBlock();
  return attrOp;
}
/// Creates an `irdl.dialect` named after the `-dialect` option at the
/// builder's insertion point. The caller is responsible for giving it a body.
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  StringAttr dialectName = StringAttr::get(ctx, selectedDialect);
  return irdl::DialectOp::create(builder, UnknownLoc::get(ctx), dialectName);
}
/// Backend entry point: prints an IRDL module describing the dialect named by
/// `-dialect`.
///
/// The printed module holds one `irdl.dialect`. It contains the matching
/// TypeDefs, then the matching AttrDefs, then the matching Ops, each in the
/// order returned by `getAllDerivedDefinitionsIfDefined`. Records from other
/// dialects are skipped. Always returns `false` (no error).
static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<irdl::IRDLDialect>();
  OpBuilder builder(&ctx);

  // Everything is built inside a fresh top-level module.
  OwningOpRef<ModuleOp> module =
      ModuleOp::create(builder, UnknownLoc::get(&ctx));
  builder.setInsertionPointToStart(module->getBody());

  // All definitions are nested in the dialect's body block.
  irdl::DialectOp dialect = createIRDLDialect(builder);
  builder.setInsertionPointToStart(&dialect.getBody().emplaceBlock());

  for (const Record *typeRec :
       records.getAllDerivedDefinitionsIfDefined("TypeDef")) {
    tblgen::TypeDef tblgenType(typeRec);
    if (tblgenType.getDialect().getName() == selectedDialect)
      createIRDLType(builder, tblgenType);
  }

  for (const Record *attrRec :
       records.getAllDerivedDefinitionsIfDefined("AttrDef")) {
    tblgen::AttrDef tblgenAttr(attrRec);
    if (tblgenAttr.getDialect().getName() == selectedDialect)
      createIRDLAttr(builder, tblgenAttr);
  }

  for (const Record *opRec : records.getAllDerivedDefinitionsIfDefined("Op")) {
    tblgen::Operator tblgenOp(opRec);
    if (tblgenOp.getDialectName() == selectedDialect)
      createIRDLOperation(builder, tblgenOp);
  }

  module->print(os);
  return false;
}
/// Registers the `-gen-dialect-irdl-defs` backend, which prints the IRDL
/// definitions of the dialect selected with `-dialect`.
static mlir::GenRegistration
    genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
              emitDialectIRDLDefs);