| //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "TestDialect.h" |
| #include "TestOps.h" |
| #include "TestTypes.h" |
| #include "mlir/Bytecode/BytecodeImplementation.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/AsmState.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/ExtensibleDialect.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/ODSSupport.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Verifier.h" |
| #include "mlir/Interfaces/CallInterfaces.h" |
| #include "mlir/Interfaces/FunctionImplementation.h" |
| #include "mlir/Interfaces/InferIntRangeInterface.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/FoldUtils.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| #include "llvm/ADT/STLFunctionalExtras.h" |
| #include "llvm/ADT/SmallString.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/Support/Base64.h" |
| #include "llvm/Support/Casting.h" |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/DLTI/DLTI.h" |
| #include "mlir/Interfaces/FoldInterfaces.h" |
| #include "mlir/Reducer/ReductionPatternInterface.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| #include <cstdint> |
| #include <numeric> |
| #include <optional> |
| |
| // Include this before the using namespace lines below to test that we don't |
| // have namespace dependencies. |
| #include "TestOpsDialect.cpp.inc" |
| |
| using namespace mlir; |
| using namespace test; |
| |
| //===----------------------------------------------------------------------===// |
| // PropertiesWithCustomPrint |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, |
| Attribute attr, |
| function_ref<InFlightDiagnostic()> emitError) { |
| DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr); |
| if (!dict) { |
| emitError() << "expected DictionaryAttr to set TestProperties"; |
| return failure(); |
| } |
| auto label = dict.getAs<mlir::StringAttr>("label"); |
| if (!label) { |
| emitError() << "expected StringAttr for key `label`"; |
| return failure(); |
| } |
| auto valueAttr = dict.getAs<IntegerAttr>("value"); |
| if (!valueAttr) { |
| emitError() << "expected IntegerAttr for key `value`"; |
| return failure(); |
| } |
| |
| prop.label = std::make_shared<std::string>(label.getValue()); |
| prop.value = valueAttr.getValue().getSExtValue(); |
| return success(); |
| } |
| |
| DictionaryAttr |
| test::getPropertiesAsAttribute(MLIRContext *ctx, |
| const PropertiesWithCustomPrint &prop) { |
| SmallVector<NamedAttribute> attrs; |
| Builder b{ctx}; |
| attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label))); |
| attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value))); |
| return b.getDictionaryAttr(attrs); |
| } |
| |
| llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) { |
| return llvm::hash_combine(prop.value, StringRef(*prop.label)); |
| } |
| |
| void test::customPrintProperties(OpAsmPrinter &p, |
| const PropertiesWithCustomPrint &prop) { |
| p.printKeywordOrString(*prop.label); |
| p << " is " << prop.value; |
| } |
| |
| ParseResult test::customParseProperties(OpAsmParser &parser, |
| PropertiesWithCustomPrint &prop) { |
| std::string label; |
| if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") || |
| parser.parseInteger(prop.value)) |
| return failure(); |
| prop.label = std::make_shared<std::string>(std::move(label)); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MyPropStruct |
| //===----------------------------------------------------------------------===// |
| |
| Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const { |
| return StringAttr::get(ctx, content); |
| } |
| |
| LogicalResult |
| MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr, |
| function_ref<InFlightDiagnostic()> emitError) { |
| StringAttr strAttr = dyn_cast<StringAttr>(attr); |
| if (!strAttr) { |
| emitError() << "Expect StringAttr but got " << attr; |
| return failure(); |
| } |
| prop.content = strAttr.getValue(); |
| return success(); |
| } |
| |
| llvm::hash_code MyPropStruct::hash() const { |
| return hash_value(StringRef(content)); |
| } |
| |
| LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader, |
| MyPropStruct &prop) { |
| StringRef str; |
| if (failed(reader.readString(str))) |
| return failure(); |
| prop.content = str.str(); |
| return success(); |
| } |
| |
| void test::writeToMlirBytecode(DialectBytecodeWriter &writer, |
| MyPropStruct &prop) { |
| writer.writeOwnedString(prop.content); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // VersionedProperties |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr, |
| function_ref<InFlightDiagnostic()> emitError) { |
| DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr); |
| if (!dict) { |
| emitError() << "expected DictionaryAttr to set VersionedProperties"; |
| return failure(); |
| } |
| auto value1Attr = dict.getAs<IntegerAttr>("value1"); |
| if (!value1Attr) { |
| emitError() << "expected IntegerAttr for key `value1`"; |
| return failure(); |
| } |
| auto value2Attr = dict.getAs<IntegerAttr>("value2"); |
| if (!value2Attr) { |
| emitError() << "expected IntegerAttr for key `value2`"; |
| return failure(); |
| } |
| |
| prop.value1 = value1Attr.getValue().getSExtValue(); |
| prop.value2 = value2Attr.getValue().getSExtValue(); |
| return success(); |
| } |
| |
| DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx, |
| const VersionedProperties &prop) { |
| SmallVector<NamedAttribute> attrs; |
| Builder b{ctx}; |
| attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1))); |
| attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2))); |
| return b.getDictionaryAttr(attrs); |
| } |
| |
| llvm::hash_code test::computeHash(const VersionedProperties &prop) { |
| return llvm::hash_combine(prop.value1, prop.value2); |
| } |
| |
| void test::customPrintProperties(OpAsmPrinter &p, |
| const VersionedProperties &prop) { |
| p << prop.value1 << " | " << prop.value2; |
| } |
| |
| ParseResult test::customParseProperties(OpAsmParser &parser, |
| VersionedProperties &prop) { |
| if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() || |
| parser.parseInteger(prop.value2)) |
| return failure(); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Bytecode Support |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader, |
| MutableArrayRef<int64_t> prop) { |
| uint64_t size; |
| if (failed(reader.readVarInt(size))) |
| return failure(); |
| if (size != prop.size()) |
| return reader.emitError("array size mismach when reading properties: ") |
| << size << " vs expected " << prop.size(); |
| for (auto &elt : prop) { |
| uint64_t value; |
| if (failed(reader.readVarInt(value))) |
| return failure(); |
| elt = value; |
| } |
| return success(); |
| } |
| |
| void test::writeToMlirBytecode(DialectBytecodeWriter &writer, |
| ArrayRef<int64_t> prop) { |
| writer.writeVarInt(prop.size()); |
| for (auto elt : prop) |
| writer.writeVarInt(elt); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Dynamic operations |
| //===----------------------------------------------------------------------===// |
| |
| std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) { |
| return DynamicOpDefinition::get( |
| "dynamic_generic", dialect, [](Operation *op) { return success(); }, |
| [](Operation *op) { return success(); }); |
| } |
| |
| std::unique_ptr<DynamicOpDefinition> |
| getDynamicOneOperandTwoResultsOp(TestDialect *dialect) { |
| return DynamicOpDefinition::get( |
| "dynamic_one_operand_two_results", dialect, |
| [](Operation *op) { |
| if (op->getNumOperands() != 1) { |
| op->emitOpError() |
| << "expected 1 operand, but had " << op->getNumOperands(); |
| return failure(); |
| } |
| if (op->getNumResults() != 2) { |
| op->emitOpError() |
| << "expected 2 results, but had " << op->getNumResults(); |
| return failure(); |
| } |
| return success(); |
| }, |
| [](Operation *op) { return success(); }); |
| } |
| |
| std::unique_ptr<DynamicOpDefinition> |
| getDynamicCustomParserPrinterOp(TestDialect *dialect) { |
| auto verifier = [](Operation *op) { |
| if (op->getNumOperands() == 0 && op->getNumResults() == 0) |
| return success(); |
| op->emitError() << "operation should have no operands and no results"; |
| return failure(); |
| }; |
| auto regionVerifier = [](Operation *op) { return success(); }; |
| |
| auto parser = [](OpAsmParser &parser, OperationState &state) { |
| return parser.parseKeyword("custom_keyword"); |
| }; |
| |
| auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { |
| printer << op->getName() << " custom_keyword"; |
| }; |
| |
| return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect, |
| verifier, regionVerifier, parser, printer); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TestDialect |
| //===----------------------------------------------------------------------===// |
| |
| void test::registerTestDialect(DialectRegistry ®istry) { |
| registry.insert<TestDialect>(); |
| } |
| |
| void test::testSideEffectOpGetEffect( |
| Operation *op, |
| SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> |
| &effects) { |
| auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); |
| if (!effectsAttr) |
| return; |
| |
| effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); |
| } |
| |
| // This is the implementation of a dialect fallback for `TestEffectOpInterface`. |
| struct TestOpEffectInterfaceFallback |
| : public TestEffectOpInterface::FallbackModel< |
| TestOpEffectInterfaceFallback> { |
| static bool classof(Operation *op) { |
| bool isSupportedOp = |
| op->getName().getStringRef() == "test.unregistered_side_effect_op"; |
| assert(isSupportedOp && "Unexpected dispatch"); |
| return isSupportedOp; |
| } |
| |
| void |
| getEffects(Operation *op, |
| SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> |
| &effects) const { |
| testSideEffectOpGetEffect(op, effects); |
| } |
| }; |
| |
| void TestDialect::initialize() { |
| registerAttributes(); |
| registerTypes(); |
| registerOpsSyntax(); |
| addOperations<ManualCppOpWithFold>(); |
| registerTestDialectOperations(this); |
| registerDynamicOp(getDynamicGenericOp(this)); |
| registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); |
| registerDynamicOp(getDynamicCustomParserPrinterOp(this)); |
| registerInterfaces(); |
| allowUnknownOperations(); |
| |
| // Instantiate our fallback op interface that we'll use on specific |
| // unregistered op. |
| fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; |
| } |
| |
| TestDialect::~TestDialect() { |
| delete static_cast<TestOpEffectInterfaceFallback *>( |
| fallbackEffectOpInterfaces); |
| } |
| |
| Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, |
| Type type, Location loc) { |
| return builder.create<TestOpConstant>(loc, type, value); |
| } |
| |
| void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, |
| OperationName opName) { |
| if (opName.getIdentifier() == "test.unregistered_side_effect_op" && |
| typeID == TypeID::get<TestEffectOpInterface>()) |
| return fallbackEffectOpInterfaces; |
| return nullptr; |
| } |
| |
| LogicalResult TestDialect::verifyOperationAttribute(Operation *op, |
| NamedAttribute namedAttr) { |
| if (namedAttr.getName() == "test.invalid_attr") |
| return op->emitError() << "invalid to use 'test.invalid_attr'"; |
| return success(); |
| } |
| |
| LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, |
| unsigned regionIndex, |
| unsigned argIndex, |
| NamedAttribute namedAttr) { |
| if (namedAttr.getName() == "test.invalid_attr") |
| return op->emitError() << "invalid to use 'test.invalid_attr'"; |
| return success(); |
| } |
| |
| LogicalResult |
| TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, |
| unsigned resultIndex, |
| NamedAttribute namedAttr) { |
| if (namedAttr.getName() == "test.invalid_attr") |
| return op->emitError() << "invalid to use 'test.invalid_attr'"; |
| return success(); |
| } |
| |
| std::optional<Dialect::ParseOpHook> |
| TestDialect::getParseOperationHook(StringRef opName) const { |
| if (opName == "test.dialect_custom_printer") { |
| return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { |
| return parser.parseKeyword("custom_format"); |
| }}; |
| } |
| if (opName == "test.dialect_custom_format_fallback") { |
| return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { |
| return parser.parseKeyword("custom_format_fallback"); |
| }}; |
| } |
| if (opName == "test.dialect_custom_printer.with.dot") { |
| return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { |
| return ParseResult::success(); |
| }}; |
| } |
| return std::nullopt; |
| } |
| |
| llvm::unique_function<void(Operation *, OpAsmPrinter &)> |
| TestDialect::getOperationPrinter(Operation *op) const { |
| StringRef opName = op->getName().getStringRef(); |
| if (opName == "test.dialect_custom_printer") { |
| return [](Operation *op, OpAsmPrinter &printer) { |
| printer.getStream() << " custom_format"; |
| }; |
| } |
| if (opName == "test.dialect_custom_format_fallback") { |
| return [](Operation *op, OpAsmPrinter &printer) { |
| printer.getStream() << " custom_format_fallback"; |
| }; |
| } |
| return {}; |
| } |
| |
| static LogicalResult |
| dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, |
| PatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| op, rewriter.getI32IntegerAttr(42)); |
| return success(); |
| } |
| |
| void TestDialect::getCanonicalizationPatterns( |
| RewritePatternSet &results) const { |
| results.add(&dialectCanonicalizationPattern); |
| } |