| //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===// |
| // |
| // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Manages the loading of MLIR objects from IRDL operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/IRDL/IRDLLoading.h" |
| #include "mlir/Dialect/IRDL/IR/IRDL.h" |
| #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" |
| #include "mlir/Dialect/IRDL/IRDLSymbols.h" |
| #include "mlir/Dialect/IRDL/IRDLVerifiers.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/ExtensibleDialect.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallPtrSet.h" |
| #include "llvm/Support/SMLoc.h" |
| #include <numeric> |
| |
| using namespace mlir; |
| using namespace mlir::irdl; |
| |
| /// Verify that the given list of parameters satisfy the given constraints. |
| /// This encodes the logic of the verification method for attributes and types |
| /// defined with IRDL. |
| static LogicalResult |
| irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError, |
| ArrayRef<Attribute> params, |
| ArrayRef<std::unique_ptr<Constraint>> constraints, |
| ArrayRef<size_t> paramConstraints) { |
| if (params.size() != paramConstraints.size()) { |
| emitError() << "expected " << paramConstraints.size() |
| << " type arguments, but had " << params.size(); |
| return failure(); |
| } |
| |
| ConstraintVerifier verifier(constraints); |
| |
| // Check that each parameter satisfies its constraint. |
| for (auto [i, param] : enumerate(params)) |
| if (failed(verifier.verify(emitError, param, paramConstraints[i]))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| /// Get the operand segment sizes from the attribute dictionary. |
| LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, |
| StringRef attrName, unsigned numElements, |
| ArrayRef<Variadicity> variadicities, |
| SmallVectorImpl<int> &segmentSizes) { |
| // Get the segment sizes attribute, and check that it is of the right type. |
| Attribute segmentSizesAttr = op->getAttr(attrName); |
| if (!segmentSizesAttr) { |
| return op->emitError() << "'" << attrName |
| << "' attribute is expected but not provided"; |
| } |
| |
| auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr); |
| if (!denseSegmentSizes) { |
| return op->emitError() << "'" << attrName |
| << "' attribute is expected to be a dense i32 array"; |
| } |
| |
| if (denseSegmentSizes.size() != (int64_t)variadicities.size()) { |
| return op->emitError() << "'" << attrName << "' attribute for specifying " |
| << elemName << " segments must have " |
| << variadicities.size() << " elements, but got " |
| << denseSegmentSizes.size(); |
| } |
| |
| // Check that the segment sizes are corresponding to the given variadicities, |
| for (auto [i, segmentSize, variadicity] : |
| enumerate(denseSegmentSizes.asArrayRef(), variadicities)) { |
| if (segmentSize < 0) |
| return op->emitError() |
| << "'" << attrName << "' attribute for specifying " << elemName |
| << " segments must have non-negative values"; |
| if (variadicity == Variadicity::single && segmentSize != 1) |
| return op->emitError() << "element " << i << " in '" << attrName |
| << "' attribute must be equal to 1"; |
| |
| if (variadicity == Variadicity::optional && segmentSize > 1) |
| return op->emitError() << "element " << i << " in '" << attrName |
| << "' attribute must be equal to 0 or 1"; |
| |
| segmentSizes.push_back(segmentSize); |
| } |
| |
| // Check that the sum of the segment sizes is equal to the number of elements. |
| int32_t sum = 0; |
| for (int32_t segmentSize : denseSegmentSizes.asArrayRef()) |
| sum += segmentSize; |
| if (sum != static_cast<int32_t>(numElements)) |
| return op->emitError() << "sum of elements in '" << attrName |
| << "' attribute must be equal to the number of " |
| << elemName << "s"; |
| |
| return success(); |
| } |
| |
| /// Compute the segment sizes of the given element (operands, results). |
| /// If the operation has more than two non-single elements (optional or |
| /// variadic), then get the segment sizes from the attribute dictionary. |
| /// Otherwise, compute the segment sizes from the number of elements. |
| /// `elemName` should be either `"operand"` or `"result"`. |
| LogicalResult getSegmentSizes(Operation *op, StringRef elemName, |
| StringRef attrName, unsigned numElements, |
| ArrayRef<Variadicity> variadicities, |
| SmallVectorImpl<int> &segmentSizes) { |
| // If we have more than one non-single variadicity, we need to get the |
| // segment sizes from the attribute dictionary. |
| int numberNonSingle = count_if( |
| variadicities, [](Variadicity v) { return v != Variadicity::single; }); |
| if (numberNonSingle > 1) |
| return getSegmentSizesFromAttr(op, elemName, attrName, numElements, |
| variadicities, segmentSizes); |
| |
| // If we only have single variadicities, the segments sizes are all 1. |
| if (numberNonSingle == 0) { |
| if (numElements != variadicities.size()) { |
| return op->emitError() << "op expects exactly " << variadicities.size() |
| << " " << elemName << "s, but got " << numElements; |
| } |
| for (size_t i = 0, e = variadicities.size(); i < e; ++i) |
| segmentSizes.push_back(1); |
| return success(); |
| } |
| |
| assert(numberNonSingle == 1); |
| |
| // There is exactly one non-single element, so we can |
| // compute its size and check that it is valid. |
| int nonSingleSegmentSize = static_cast<int>(numElements) - |
| static_cast<int>(variadicities.size()) + 1; |
| |
| if (nonSingleSegmentSize < 0) { |
| return op->emitError() << "op expects at least " << variadicities.size() - 1 |
| << " " << elemName << "s, but got " << numElements; |
| } |
| |
| // Add the segment sizes. |
| for (Variadicity variadicity : variadicities) { |
| if (variadicity == Variadicity::single) { |
| segmentSizes.push_back(1); |
| continue; |
| } |
| |
| // If we have an optional element, we should check that it represents |
| // zero or one elements. |
| if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional) |
| return op->emitError() << "op expects at most " << variadicities.size() |
| << " " << elemName << "s, but got " << numElements; |
| |
| segmentSizes.push_back(nonSingleSegmentSize); |
| } |
| |
| return success(); |
| } |
| |
| /// Compute the segment sizes of the given operands. |
| /// If the operation has more than two non-single operands (optional or |
| /// variadic), then get the segment sizes from the attribute dictionary. |
| /// Otherwise, compute the segment sizes from the number of operands. |
| LogicalResult getOperandSegmentSizes(Operation *op, |
| ArrayRef<Variadicity> variadicities, |
| SmallVectorImpl<int> &segmentSizes) { |
| return getSegmentSizes(op, "operand", "operand_segment_sizes", |
| op->getNumOperands(), variadicities, segmentSizes); |
| } |
| |
| /// Compute the segment sizes of the given results. |
| /// If the operation has more than two non-single results (optional or |
| /// variadic), then get the segment sizes from the attribute dictionary. |
| /// Otherwise, compute the segment sizes from the number of results. |
| LogicalResult getResultSegmentSizes(Operation *op, |
| ArrayRef<Variadicity> variadicities, |
| SmallVectorImpl<int> &segmentSizes) { |
| return getSegmentSizes(op, "result", "result_segment_sizes", |
| op->getNumResults(), variadicities, segmentSizes); |
| } |
| |
| /// Verify that the given operation satisfies the given constraints. |
| /// This encodes the logic of the verification method for operations defined |
| /// with IRDL. |
| static LogicalResult irdlOpVerifier( |
| Operation *op, ConstraintVerifier &verifier, |
| ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity, |
| ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity, |
| const DenseMap<StringAttr, size_t> &attributeConstrs) { |
| // Get the segment sizes for the operands. |
| // This will check that the number of operands is correct. |
| SmallVector<int> operandSegmentSizes; |
| if (failed( |
| getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes))) |
| return failure(); |
| |
| // Get the segment sizes for the results. |
| // This will check that the number of results is correct. |
| SmallVector<int> resultSegmentSizes; |
| if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes))) |
| return failure(); |
| |
| auto emitError = [op] { return op->emitError(); }; |
| |
| /// Сheck that we have all needed attributes passed |
| /// and they satisfy the constraints. |
| DictionaryAttr actualAttrs = op->getAttrDictionary(); |
| |
| for (auto [name, constraint] : attributeConstrs) { |
| /// First, check if the attribute actually passed. |
| std::optional<NamedAttribute> actual = actualAttrs.getNamed(name); |
| if (!actual.has_value()) |
| return op->emitOpError() |
| << "attribute " << name << " is expected but not provided"; |
| |
| /// Then, check if the attribute value satisfies the constraint. |
| if (failed(verifier.verify({emitError}, actual->getValue(), constraint))) |
| return failure(); |
| } |
| |
| // Check that all operands satisfy the constraints |
| int operandIdx = 0; |
| for (auto [defIndex, segmentSize] : enumerate(operandSegmentSizes)) { |
| for (int i = 0; i < segmentSize; i++) { |
| if (failed(verifier.verify( |
| {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]), |
| operandConstrs[defIndex]))) |
| return failure(); |
| ++operandIdx; |
| } |
| } |
| |
| // Check that all results satisfy the constraints |
| int resultIdx = 0; |
| for (auto [defIndex, segmentSize] : enumerate(resultSegmentSizes)) { |
| for (int i = 0; i < segmentSize; i++) { |
| if (failed(verifier.verify({emitError}, |
| TypeAttr::get(op->getResultTypes()[resultIdx]), |
| resultConstrs[defIndex]))) |
| return failure(); |
| ++resultIdx; |
| } |
| } |
| |
| return success(); |
| } |
| |
| static LogicalResult irdlRegionVerifier( |
| Operation *op, ConstraintVerifier &verifier, |
| ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) { |
| if (op->getNumRegions() != regionsConstraints.size()) { |
| return op->emitOpError() |
| << "unexpected number of regions: expected " |
| << regionsConstraints.size() << " but got " << op->getNumRegions(); |
| } |
| |
| for (auto [constraint, region] : |
| llvm::zip(regionsConstraints, op->getRegions())) |
| if (failed(constraint->verify(region, verifier))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| llvm::unique_function<LogicalResult(Operation *) const> |
| mlir::irdl::createVerifier( |
| OperationOp op, |
| const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
| const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
| &attrs) { |
| // Resolve SSA values to verifier constraint slots |
| SmallVector<Value> constrToValue; |
| SmallVector<Value> regionToValue; |
| for (Operation &op : op->getRegion(0).getOps()) { |
| if (isa<VerifyConstraintInterface>(op)) { |
| if (op.getNumResults() != 1) { |
| op.emitError() |
| << "IRDL constraint operations must have exactly one result"; |
| return nullptr; |
| } |
| constrToValue.push_back(op.getResult(0)); |
| } |
| if (isa<VerifyRegionInterface>(op)) { |
| if (op.getNumResults() != 1) { |
| op.emitError() |
| << "IRDL constraint operations must have exactly one result"; |
| return nullptr; |
| } |
| regionToValue.push_back(op.getResult(0)); |
| } |
| } |
| |
| // Build the verifiers for each constraint slot |
| SmallVector<std::unique_ptr<Constraint>> constraints; |
| for (Value v : constrToValue) { |
| VerifyConstraintInterface op = |
| cast<VerifyConstraintInterface>(v.getDefiningOp()); |
| std::unique_ptr<Constraint> verifier = |
| op.getVerifier(constrToValue, types, attrs); |
| if (!verifier) |
| return nullptr; |
| constraints.push_back(std::move(verifier)); |
| } |
| |
| // Build region constraints |
| SmallVector<std::unique_ptr<RegionConstraint>> regionConstraints; |
| for (Value v : regionToValue) { |
| VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp()); |
| std::unique_ptr<RegionConstraint> verifier = |
| op.getVerifier(constrToValue, types, attrs); |
| regionConstraints.push_back(std::move(verifier)); |
| } |
| |
| SmallVector<size_t> operandConstraints; |
| SmallVector<Variadicity> operandVariadicity; |
| |
| // Gather which constraint slots correspond to operand constraints |
| auto operandsOp = op.getOp<OperandsOp>(); |
| if (operandsOp.has_value()) { |
| operandConstraints.reserve(operandsOp->getArgs().size()); |
| for (Value operand : operandsOp->getArgs()) { |
| for (auto [i, constr] : enumerate(constrToValue)) { |
| if (constr == operand) { |
| operandConstraints.push_back(i); |
| break; |
| } |
| } |
| } |
| |
| // Gather the variadicities of each operand |
| for (VariadicityAttr attr : operandsOp->getVariadicity()) |
| operandVariadicity.push_back(attr.getValue()); |
| } |
| |
| SmallVector<size_t> resultConstraints; |
| SmallVector<Variadicity> resultVariadicity; |
| |
| // Gather which constraint slots correspond to result constraints |
| auto resultsOp = op.getOp<ResultsOp>(); |
| if (resultsOp.has_value()) { |
| resultConstraints.reserve(resultsOp->getArgs().size()); |
| for (Value result : resultsOp->getArgs()) { |
| for (auto [i, constr] : enumerate(constrToValue)) { |
| if (constr == result) { |
| resultConstraints.push_back(i); |
| break; |
| } |
| } |
| } |
| |
| // Gather the variadicities of each result |
| for (Attribute attr : resultsOp->getVariadicity()) |
| resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue()); |
| } |
| |
| // Gather which constraint slots correspond to attributes constraints |
| DenseMap<StringAttr, size_t> attributeConstraints; |
| auto attributesOp = op.getOp<AttributesOp>(); |
| if (attributesOp.has_value()) { |
| const Operation::operand_range values = attributesOp->getAttributeValues(); |
| const ArrayAttr names = attributesOp->getAttributeValueNames(); |
| |
| for (const auto &[name, value] : llvm::zip(names, values)) { |
| for (auto [i, constr] : enumerate(constrToValue)) { |
| if (constr == value) { |
| attributeConstraints[cast<StringAttr>(name)] = i; |
| break; |
| } |
| } |
| } |
| } |
| |
| return |
| [constraints{std::move(constraints)}, |
| regionConstraints{std::move(regionConstraints)}, |
| operandConstraints{std::move(operandConstraints)}, |
| operandVariadicity{std::move(operandVariadicity)}, |
| resultConstraints{std::move(resultConstraints)}, |
| resultVariadicity{std::move(resultVariadicity)}, |
| attributeConstraints{std::move(attributeConstraints)}](Operation *op) { |
| ConstraintVerifier verifier(constraints); |
| const LogicalResult opVerifierResult = irdlOpVerifier( |
| op, verifier, operandConstraints, operandVariadicity, |
| resultConstraints, resultVariadicity, attributeConstraints); |
| const LogicalResult opRegionVerifierResult = |
| irdlRegionVerifier(op, verifier, regionConstraints); |
| return LogicalResult::success(opVerifierResult.succeeded() && |
| opRegionVerifierResult.succeeded()); |
| }; |
| } |
| |
| /// Define and load an operation represented by a `irdl.operation` |
| /// operation. |
| static WalkResult loadOperation( |
| OperationOp op, ExtensibleDialect *dialect, |
| const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
| const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
| &attrs) { |
| |
| // IRDL does not support defining custom parsers or printers. |
| auto parser = [](OpAsmParser &parser, OperationState &result) { |
| return failure(); |
| }; |
| auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) { |
| printer.printGenericOp(op); |
| }; |
| |
| auto verifier = createVerifier(op, types, attrs); |
| if (!verifier) |
| return WalkResult::interrupt(); |
| |
| // IRDL supports only checking number of blocks and argument constraints |
| // It is done in the main verifier to reuse `ConstraintVerifier` context |
| auto regionVerifier = [](Operation *op) { return LogicalResult::success(); }; |
| |
| auto opDef = DynamicOpDefinition::get( |
| op.getName(), dialect, std::move(verifier), std::move(regionVerifier), |
| std::move(parser), std::move(printer)); |
| dialect->registerDynamicOp(std::move(opDef)); |
| |
| return WalkResult::advance(); |
| } |
| |
| /// Get the verifier of a type or attribute definition. |
| /// Return nullptr if the definition is invalid. |
| static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier( |
| Operation *attrOrTypeDef, ExtensibleDialect *dialect, |
| DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
| DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { |
| assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) && |
| "Expected an attribute or type definition"); |
| |
| // Resolve SSA values to verifier constraint slots |
| SmallVector<Value> constrToValue; |
| for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) { |
| if (isa<VerifyConstraintInterface>(op)) { |
| assert(op.getNumResults() == 1 && |
| "IRDL constraint operations must have exactly one result"); |
| constrToValue.push_back(op.getResult(0)); |
| } |
| } |
| |
| // Build the verifiers for each constraint slot |
| SmallVector<std::unique_ptr<Constraint>> constraints; |
| for (Value v : constrToValue) { |
| VerifyConstraintInterface op = |
| cast<VerifyConstraintInterface>(v.getDefiningOp()); |
| std::unique_ptr<Constraint> verifier = |
| op.getVerifier(constrToValue, types, attrs); |
| if (!verifier) |
| return {}; |
| constraints.push_back(std::move(verifier)); |
| } |
| |
| // Get the parameter definitions. |
| std::optional<ParametersOp> params; |
| if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef)) |
| params = attr.getOp<ParametersOp>(); |
| else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef)) |
| params = type.getOp<ParametersOp>(); |
| |
| // Gather which constraint slots correspond to parameter constraints |
| SmallVector<size_t> paramConstraints; |
| if (params.has_value()) { |
| paramConstraints.reserve(params->getArgs().size()); |
| for (Value param : params->getArgs()) { |
| for (auto [i, constr] : enumerate(constrToValue)) { |
| if (constr == param) { |
| paramConstraints.push_back(i); |
| break; |
| } |
| } |
| } |
| } |
| |
| auto verifier = [paramConstraints{std::move(paramConstraints)}, |
| constraints{std::move(constraints)}]( |
| function_ref<InFlightDiagnostic()> emitError, |
| ArrayRef<Attribute> params) { |
| return irdlAttrOrTypeVerifier(emitError, params, constraints, |
| paramConstraints); |
| }; |
| |
| // While the `std::move` is not required, not adding it triggers a bug in |
| // clang-10. |
| return std::move(verifier); |
| } |
| |
| /// Get the possible bases of a constraint. Return `true` if all bases can |
| /// potentially be matched. |
| /// A base is a type or an attribute definition. For instance, the base of |
| /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`. |
| /// This function returns the following information through arguments: |
| /// - `paramIds`: the set of type or attribute IDs that are used as bases. |
| /// - `paramIrdlOps`: the set of IRDL operations that are used as bases. |
| /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is` |
| /// constraints. |
| static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> ¶mIds, |
| SmallPtrSet<Operation *, 4> ¶mIrdlOps, |
| SmallPtrSet<TypeID, 4> &isIds) { |
| // For `irdl.any_of`, we get the bases from all its arguments. |
| if (auto anyOf = dyn_cast<AnyOfOp>(op)) { |
| bool hasAny = false; |
| for (Value arg : anyOf.getArgs()) |
| hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds); |
| return hasAny; |
| } |
| |
| // For `irdl.all_of`, we get the bases from the first argument. |
| // This is restrictive, but we can relax it later if needed. |
| if (auto allOf = dyn_cast<AllOfOp>(op)) |
| return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps, |
| isIds); |
| |
| // For `irdl.parametric`, we get directly the base from the operation. |
| if (auto params = dyn_cast<ParametricOp>(op)) { |
| SymbolRefAttr symRef = params.getBaseType(); |
| Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef); |
| assert(defOp && "symbol reference should refer to an existing operation"); |
| paramIrdlOps.insert(defOp); |
| return false; |
| } |
| |
| // For `irdl.is`, we get the base TypeID directly. |
| if (auto is = dyn_cast<IsOp>(op)) { |
| Attribute expected = is.getExpected(); |
| isIds.insert(expected.getTypeID()); |
| return false; |
| } |
| |
| // For `irdl.any`, we return `false` since we can match any type or attribute |
| // base. |
| if (auto isA = dyn_cast<AnyOp>(op)) |
| return true; |
| |
| llvm_unreachable("unknown IRDL constraint"); |
| } |
| |
| /// Check that an any_of is in the subset IRDL can handle. |
| /// IRDL uses a greedy algorithm to match constraints. This means that if we |
| /// encounter an `any_of` with multiple constraints, we will match the first |
| /// constraint that is satisfied. Thus, the order of constraints matter in |
| /// `any_of` with our current algorithm. |
| /// In order to make the order of constraints irrelevant, we require that |
| /// all `any_of` constraint parameters are disjoint. For this, we check that |
| /// the base parameters are all disjoints between `parametric` operations, and |
| /// that they are disjoint between `parametric` and `is` operations. |
| /// This restriction will be relaxed in the future, when we will change our |
| /// algorithm to be non-greedy. |
| static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) { |
| SmallPtrSet<TypeID, 4> paramIds; |
| SmallPtrSet<Operation *, 4> paramIrdlOps; |
| SmallPtrSet<TypeID, 4> isIds; |
| |
| for (Value arg : anyOf.getArgs()) { |
| Operation *argOp = arg.getDefiningOp(); |
| SmallPtrSet<TypeID, 4> argParamIds; |
| SmallPtrSet<Operation *, 4> argParamIrdlOps; |
| SmallPtrSet<TypeID, 4> argIsIds; |
| |
| // Get the bases of this argument. If it can match any type or attribute, |
| // then our `any_of` should not be allowed. |
| if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds)) |
| return failure(); |
| |
| // We check that the base parameters are all disjoints between `parametric` |
| // operations, and that they are disjoint between `parametric` and `is` |
| // operations. |
| for (TypeID id : argParamIds) { |
| if (isIds.count(id)) |
| return failure(); |
| bool inserted = paramIds.insert(id).second; |
| if (!inserted) |
| return failure(); |
| } |
| |
| // We check that the base parameters are all disjoints with `irdl.is` |
| // operations. |
| for (TypeID id : isIds) { |
| if (paramIds.count(id)) |
| return failure(); |
| isIds.insert(id); |
| } |
| |
| // We check that all `parametric` operations are disjoint. We do not |
| // need to check that they are disjoint with `is` operations, since |
| // `is` operations cannot refer to attributes defined with `irdl.parametric` |
| // operations. |
| for (Operation *op : argParamIrdlOps) { |
| bool inserted = paramIrdlOps.insert(op).second; |
| if (!inserted) |
| return failure(); |
| } |
| } |
| |
| return success(); |
| } |
| |
| /// Load all dialects in the given module, without loading any operation, type |
| /// or attribute definitions. |
| static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) { |
| DenseMap<DialectOp, ExtensibleDialect *> dialects; |
| op.walk([&](DialectOp dialectOp) { |
| MLIRContext *ctx = dialectOp.getContext(); |
| StringRef dialectName = dialectOp.getName(); |
| |
| DynamicDialect *dialect = ctx->getOrLoadDynamicDialect( |
| dialectName, [](DynamicDialect *dialect) {}); |
| |
| dialects.insert({dialectOp, dialect}); |
| }); |
| return dialects; |
| } |
| |
| /// Preallocate type definitions objects with empty verifiers. |
| /// This in particular allocates a TypeID for each type definition. |
| static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> |
| preallocateTypeDefs(ModuleOp op, |
| DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
| DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs; |
| op.walk([&](TypeOp typeOp) { |
| ExtensibleDialect *dialect = dialects[typeOp.getParentOp()]; |
| auto typeDef = DynamicTypeDefinition::get( |
| typeOp.getName(), dialect, |
| [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
| return success(); |
| }); |
| typeDefs.try_emplace(typeOp, std::move(typeDef)); |
| }); |
| return typeDefs; |
| } |
| |
| /// Preallocate attribute definitions objects with empty verifiers. |
| /// This in particular allocates a TypeID for each attribute definition. |
| static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
| preallocateAttrDefs(ModuleOp op, |
| DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
| DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs; |
| op.walk([&](AttributeOp attrOp) { |
| ExtensibleDialect *dialect = dialects[attrOp.getParentOp()]; |
| auto attrDef = DynamicAttrDefinition::get( |
| attrOp.getName(), dialect, |
| [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
| return success(); |
| }); |
| attrDefs.try_emplace(attrOp, std::move(attrDef)); |
| }); |
| return attrDefs; |
| } |
| |
| LogicalResult mlir::irdl::loadDialects(ModuleOp op) { |
| // First, check that all any_of constraints are in a correct form. |
| // This is to ensure we can do the verification correctly. |
| WalkResult anyOfCorrects = op.walk( |
| [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); }); |
| if (anyOfCorrects.wasInterrupted()) |
| return op.emitError("any_of constraints are not in the correct form"); |
| |
| // Preallocate all dialects, and type and attribute definitions. |
| // In particular, this allocates TypeIDs so type and attributes can have |
| // verifiers that refer to each other. |
| DenseMap<DialectOp, ExtensibleDialect *> dialects = loadEmptyDialects(op); |
| DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> types = |
| preallocateTypeDefs(op, dialects); |
| DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs = |
| preallocateAttrDefs(op, dialects); |
| |
| // Set the verifier for types. |
| WalkResult res = op.walk([&](TypeOp typeOp) { |
| DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
| typeOp, dialects[typeOp.getParentOp()], types, attrs); |
| if (!verifier) |
| return WalkResult::interrupt(); |
| types[typeOp]->setVerifyFn(std::move(verifier)); |
| return WalkResult::advance(); |
| }); |
| if (res.wasInterrupted()) |
| return failure(); |
| |
| // Set the verifier for attributes. |
| res = op.walk([&](AttributeOp attrOp) { |
| DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
| attrOp, dialects[attrOp.getParentOp()], types, attrs); |
| if (!verifier) |
| return WalkResult::interrupt(); |
| attrs[attrOp]->setVerifyFn(std::move(verifier)); |
| return WalkResult::advance(); |
| }); |
| if (res.wasInterrupted()) |
| return failure(); |
| |
| // Define and load all operations. |
| res = op.walk([&](OperationOp opOp) { |
| return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs); |
| }); |
| if (res.wasInterrupted()) |
| return failure(); |
| |
| // Load all types in their dialects. |
| for (auto &pair : types) { |
| ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
| dialect->registerDynamicType(std::move(pair.second)); |
| } |
| |
| // Load all attributes in their dialects. |
| for (auto &pair : attrs) { |
| ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
| dialect->registerDynamicAttr(std::move(pair.second)); |
| } |
| |
| return success(); |
| } |