| //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the dialect for the Toy IR: custom type parsing and |
| // operation verification. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "toy/Dialect.h" |
| |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| |
| using namespace mlir; |
| using namespace mlir::toy; |
| |
| //===----------------------------------------------------------------------===// |
| // ToyInlinerInterface |
| //===----------------------------------------------------------------------===// |
| |
| /// This class defines the interface for handling inlining with Toy |
| /// operations. |
| struct ToyInlinerInterface : public DialectInlinerInterface { |
| using DialectInlinerInterface::DialectInlinerInterface; |
| |
| //===--------------------------------------------------------------------===// |
| // Analysis Hooks |
| //===--------------------------------------------------------------------===// |
| |
| /// All operations within toy can be inlined. |
| bool isLegalToInline(Operation *, Region *, |
| BlockAndValueMapping &) const final { |
| return true; |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Transformation Hooks |
| //===--------------------------------------------------------------------===// |
| |
| /// Handle the given inlined terminator(toy.return) by replacing it with a new |
| /// operation as necessary. |
| void handleTerminator(Operation *op, |
| ArrayRef<Value> valuesToRepl) const final { |
| // Only "toy.return" needs to be handled here. |
| auto returnOp = cast<ReturnOp>(op); |
| |
| // Replace the values directly with the return operands. |
| assert(returnOp.getNumOperands() == valuesToRepl.size()); |
| for (const auto &it : llvm::enumerate(returnOp.getOperands())) |
| valuesToRepl[it.index()].replaceAllUsesWith(it.value()); |
| } |
| |
| /// Attempts to materialize a conversion for a type mismatch between a call |
| /// from this dialect, and a callable region. This method should generate an |
| /// operation that takes 'input' as the only operand, and produces a single |
| /// result of 'resultType'. If a conversion can not be generated, nullptr |
| /// should be returned. |
| Operation *materializeCallConversion(OpBuilder &builder, Value input, |
| Type resultType, |
| Location conversionLoc) const final { |
| return builder.create<CastOp>(conversionLoc, resultType, input); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ToyDialect |
| //===----------------------------------------------------------------------===// |
| |
| /// Dialect creation, the instance will be owned by the context. This is the |
| /// point of registration of custom types and operations for the dialect. |
| ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { |
| addOperations< |
| #define GET_OP_LIST |
| #include "toy/Ops.cpp.inc" |
| >(); |
| addInterfaces<ToyInlinerInterface>(); |
| addTypes<StructType>(); |
| } |
| |
| mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, |
| mlir::Attribute value, |
| mlir::Type type, |
| mlir::Location loc) { |
| if (type.isa<StructType>()) |
| return builder.create<StructConstantOp>(loc, type, |
| value.cast<mlir::ArrayAttr>()); |
| return builder.create<ConstantOp>(loc, type, |
| value.cast<mlir::DenseElementsAttr>()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Toy Operations |
| //===----------------------------------------------------------------------===// |
| |
| /// A generalized parser for binary operations. This parses the different forms |
| /// of 'printBinaryOp' below. |
| static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, |
| mlir::OperationState &result) { |
| SmallVector<mlir::OpAsmParser::OperandType, 2> operands; |
| llvm::SMLoc operandsLoc = parser.getCurrentLocation(); |
| Type type; |
| if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type)) |
| return mlir::failure(); |
| |
| // If the type is a function type, it contains the input and result types of |
| // this operation. |
| if (FunctionType funcType = type.dyn_cast<FunctionType>()) { |
| if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, |
| result.operands)) |
| return mlir::failure(); |
| result.addTypes(funcType.getResults()); |
| return mlir::success(); |
| } |
| |
| // Otherwise, the parsed type is the type of both operands and results. |
| if (parser.resolveOperands(operands, type, result.operands)) |
| return mlir::failure(); |
| result.addTypes(type); |
| return mlir::success(); |
| } |
| |
| /// A generalized printer for binary operations. It prints in two different |
| /// forms depending on if all of the types match. |
| static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { |
| printer << op->getName() << " " << op->getOperands(); |
| printer.printOptionalAttrDict(op->getAttrs()); |
| printer << " : "; |
| |
| // If all of the types are the same, print the type directly. |
| Type resultType = *op->result_type_begin(); |
| if (llvm::all_of(op->getOperandTypes(), |
| [=](Type type) { return type == resultType; })) { |
| printer << resultType; |
| return; |
| } |
| |
| // Otherwise, print a functional type. |
| printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConstantOp |
| |
| /// Build a constant operation. |
| /// The builder is passed as an argument, so is the state that this method is |
| /// expected to fill in order to build the operation. |
| void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, |
| double value) { |
| auto dataType = RankedTensorType::get({}, builder.getF64Type()); |
| auto dataAttribute = DenseElementsAttr::get(dataType, value); |
| ConstantOp::build(builder, state, dataType, dataAttribute); |
| } |
| |
| /// The 'OpAsmParser' class provides a collection of methods for parsing |
| /// various punctuation, as well as attributes, operands, types, etc. Each of |
| /// these methods returns a `ParseResult`. This class is a wrapper around |
| /// `LogicalResult` that can be converted to a boolean `true` value on failure, |
| /// or `false` on success. This allows for easily chaining together a set of |
| /// parser rules. These rules are used to populate an `mlir::OperationState` |
| /// similarly to the `build` methods described above. |
| static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser, |
| mlir::OperationState &result) { |
| mlir::DenseElementsAttr value; |
| if (parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseAttribute(value, "value", result.attributes)) |
| return failure(); |
| |
| result.addTypes(value.getType()); |
| return success(); |
| } |
| |
| /// The 'OpAsmPrinter' class is a stream that allows for formatting |
| /// strings, attributes, operands, types, etc. |
| static void print(mlir::OpAsmPrinter &printer, ConstantOp op) { |
| printer << "toy.constant "; |
| printer.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); |
| printer << op.value(); |
| } |
| |
| /// Verify that the given attribute value is valid for the given type. |
| static mlir::LogicalResult verifyConstantForType(mlir::Type type, |
| mlir::Attribute opaqueValue, |
| mlir::Operation *op) { |
| if (type.isa<mlir::TensorType>()) { |
| // Check that the value is an elements attribute. |
| auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>(); |
| if (!attrValue) |
| return op->emitError("constant of TensorType must be initialized by " |
| "a DenseFPElementsAttr, got ") |
| << opaqueValue; |
| |
| // If the return type of the constant is not an unranked tensor, the shape |
| // must match the shape of the attribute holding the data. |
| auto resultType = type.dyn_cast<mlir::RankedTensorType>(); |
| if (!resultType) |
| return success(); |
| |
| // Check that the rank of the attribute type matches the rank of the |
| // constant result type. |
| auto attrType = attrValue.getType().cast<mlir::TensorType>(); |
| if (attrType.getRank() != resultType.getRank()) { |
| return op->emitOpError("return type must match the one of the attached " |
| "value attribute: ") |
| << attrType.getRank() << " != " << resultType.getRank(); |
| } |
| |
| // Check that each of the dimensions match between the two types. |
| for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { |
| if (attrType.getShape()[dim] != resultType.getShape()[dim]) { |
| return op->emitOpError( |
| "return type shape mismatches its attribute at dimension ") |
| << dim << ": " << attrType.getShape()[dim] |
| << " != " << resultType.getShape()[dim]; |
| } |
| } |
| return mlir::success(); |
| } |
| auto resultType = type.cast<StructType>(); |
| llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes(); |
| |
| // Verify that the initializer is an Array. |
| auto attrValue = opaqueValue.dyn_cast<ArrayAttr>(); |
| if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) |
| return op->emitError("constant of StructType must be initialized by an " |
| "ArrayAttr with the same number of elements, got ") |
| << opaqueValue; |
| |
| // Check that each of the elements are valid. |
| llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue(); |
| for (const auto &it : llvm::zip(resultElementTypes, attrElementValues)) |
| if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) |
| return mlir::failure(); |
| return mlir::success(); |
| } |
| |
| /// Verifier for the constant operation. This corresponds to the `::verify(...)` |
| /// in the op definition. |
| static mlir::LogicalResult verify(ConstantOp op) { |
| return verifyConstantForType(op.getResult().getType(), op.value(), op); |
| } |
| |
| static mlir::LogicalResult verify(StructConstantOp op) { |
| return verifyConstantForType(op.getResult().getType(), op.value(), op); |
| } |
| |
| /// Infer the output shape of the ConstantOp, this is required by the shape |
| /// inference interface. |
| void ConstantOp::inferShapes() { getResult().setType(value().getType()); } |
| |
| //===----------------------------------------------------------------------===// |
| // AddOp |
| |
| void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, |
| mlir::Value lhs, mlir::Value rhs) { |
| state.addTypes(UnrankedTensorType::get(builder.getF64Type())); |
| state.addOperands({lhs, rhs}); |
| } |
| |
| /// Infer the output shape of the AddOp, this is required by the shape inference |
| /// interface. |
| void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); } |
| |
| //===----------------------------------------------------------------------===// |
| // CastOp |
| |
| /// Infer the output shape of the CastOp, this is required by the shape |
| /// inference interface. |
| void CastOp::inferShapes() { getResult().setType(getOperand().getType()); } |
| |
| //===----------------------------------------------------------------------===// |
| // GenericCallOp |
| |
| void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, |
| StringRef callee, ArrayRef<mlir::Value> arguments) { |
| // Generic call always returns an unranked Tensor initially. |
| state.addTypes(UnrankedTensorType::get(builder.getF64Type())); |
| state.addOperands(arguments); |
| state.addAttribute("callee", builder.getSymbolRefAttr(callee)); |
| } |
| |
| /// Return the callee of the generic call operation, this is required by the |
| /// call interface. |
| CallInterfaceCallable GenericCallOp::getCallableForCallee() { |
| return getAttrOfType<SymbolRefAttr>("callee"); |
| } |
| |
| /// Get the argument operands to the called function, this is required by the |
| /// call interface. |
| Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } |
| |
| //===----------------------------------------------------------------------===// |
| // MulOp |
| |
| void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, |
| mlir::Value lhs, mlir::Value rhs) { |
| state.addTypes(UnrankedTensorType::get(builder.getF64Type())); |
| state.addOperands({lhs, rhs}); |
| } |
| |
| /// Infer the output shape of the MulOp, this is required by the shape inference |
| /// interface. |
| void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); } |
| |
| //===----------------------------------------------------------------------===// |
| // ReturnOp |
| |
| static mlir::LogicalResult verify(ReturnOp op) { |
| // We know that the parent operation is a function, because of the 'HasParent' |
| // trait attached to the operation definition. |
| auto function = cast<FuncOp>(op.getParentOp()); |
| |
| /// ReturnOps can only have a single optional operand. |
| if (op.getNumOperands() > 1) |
| return op.emitOpError() << "expects at most 1 return operand"; |
| |
| // The operand number and types must match the function signature. |
| const auto &results = function.getType().getResults(); |
| if (op.getNumOperands() != results.size()) |
| return op.emitOpError() |
| << "does not return the same number of values (" |
| << op.getNumOperands() << ") as the enclosing function (" |
| << results.size() << ")"; |
| |
| // If the operation does not have an input, we are done. |
| if (!op.hasOperand()) |
| return mlir::success(); |
| |
| auto inputType = *op.operand_type_begin(); |
| auto resultType = results.front(); |
| |
| // Check that the result type of the function matches the operand type. |
| if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() || |
| resultType.isa<mlir::UnrankedTensorType>()) |
| return mlir::success(); |
| |
| return op.emitError() << "type of return operand (" << inputType |
| << ") doesn't match function result type (" |
| << resultType << ")"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // StructAccessOp |
| |
| void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, |
| mlir::Value input, size_t index) { |
| // Extract the result type from the input type. |
| StructType structTy = input.getType().cast<StructType>(); |
| assert(index < structTy.getNumElementTypes()); |
| mlir::Type resultType = structTy.getElementTypes()[index]; |
| |
| // Call into the auto-generated build method. |
| build(b, state, resultType, input, b.getI64IntegerAttr(index)); |
| } |
| |
| static mlir::LogicalResult verify(StructAccessOp op) { |
| StructType structTy = op.input().getType().cast<StructType>(); |
| size_t index = op.index().getZExtValue(); |
| if (index >= structTy.getNumElementTypes()) |
| return op.emitOpError() |
| << "index should be within the range of the input struct type"; |
| mlir::Type resultType = op.getResult().getType(); |
| if (resultType != structTy.getElementTypes()[index]) |
| return op.emitOpError() << "must have the same result type as the struct " |
| "element referred to by the index"; |
| return mlir::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TransposeOp |
| |
| void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, |
| mlir::Value value) { |
| state.addTypes(UnrankedTensorType::get(builder.getF64Type())); |
| state.addOperands(value); |
| } |
| |
| void TransposeOp::inferShapes() { |
| auto arrayTy = getOperand().getType().cast<RankedTensorType>(); |
| SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); |
| getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); |
| } |
| |
| static mlir::LogicalResult verify(TransposeOp op) { |
| auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>(); |
| auto resultType = op.getType().dyn_cast<RankedTensorType>(); |
| if (!inputType || !resultType) |
| return mlir::success(); |
| |
| auto inputShape = inputType.getShape(); |
| if (!std::equal(inputShape.begin(), inputShape.end(), |
| resultType.getShape().rbegin())) { |
| return op.emitError() |
| << "expected result shape to be a transpose of the input"; |
| } |
| return mlir::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Toy Types |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| namespace toy { |
| namespace detail { |
| /// This class represents the internal storage of the Toy `StructType`. |
| struct StructTypeStorage : public mlir::TypeStorage { |
| /// The `KeyTy` is a required type that provides an interface for the storage |
| /// instance. This type will be used when uniquing an instance of the type |
| /// storage. For our struct type, we will unique each instance structurally on |
| /// the elements that it contains. |
| using KeyTy = llvm::ArrayRef<mlir::Type>; |
| |
| /// A constructor for the type storage instance. |
| StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes) |
| : elementTypes(elementTypes) {} |
| |
| /// Define the comparison function for the key type with the current storage |
| /// instance. This is used when constructing a new instance to ensure that we |
| /// haven't already uniqued an instance of the given key. |
| bool operator==(const KeyTy &key) const { return key == elementTypes; } |
| |
| /// Define a hash function for the key type. This is used when uniquing |
| /// instances of the storage, see the `StructType::get` method. |
| /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type |
| /// have hash functions available, so we could just omit this entirely. |
| static llvm::hash_code hashKey(const KeyTy &key) { |
| return llvm::hash_value(key); |
| } |
| |
| /// Define a construction function for the key type from a set of parameters. |
| /// These parameters will be provided when constructing the storage instance |
| /// itself. |
| /// Note: This method isn't necessary because KeyTy can be directly |
| /// constructed with the given parameters. |
| static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) { |
| return KeyTy(elementTypes); |
| } |
| |
| /// Define a construction method for creating a new instance of this storage. |
| /// This method takes an instance of a storage allocator, and an instance of a |
| /// `KeyTy`. The given allocator must be used for *all* necessary dynamic |
| /// allocations used to create the type storage and its internal. |
| static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, |
| const KeyTy &key) { |
| // Copy the elements from the provided `KeyTy` into the allocator. |
| llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key); |
| |
| // Allocate the storage instance and construct it. |
| return new (allocator.allocate<StructTypeStorage>()) |
| StructTypeStorage(elementTypes); |
| } |
| |
| /// The following field contains the element types of the struct. |
| llvm::ArrayRef<mlir::Type> elementTypes; |
| }; |
| } // end namespace detail |
| } // end namespace toy |
| } // end namespace mlir |
| |
| /// Create an instance of a `StructType` with the given element types. There |
| /// *must* be at least one element type. |
| StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) { |
| assert(!elementTypes.empty() && "expected at least 1 element type"); |
| |
| // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance |
| // of this type. The first two parameters are the context to unique in and the |
| // kind of the type. The parameters after the type kind are forwarded to the |
| // storage instance. |
| mlir::MLIRContext *ctx = elementTypes.front().getContext(); |
| return Base::get(ctx, ToyTypes::Struct, elementTypes); |
| } |
| |
| /// Returns the element types of this struct type. |
| llvm::ArrayRef<mlir::Type> StructType::getElementTypes() { |
| // 'getImpl' returns a pointer to the internal storage instance. |
| return getImpl()->elementTypes; |
| } |
| |
| /// Parse an instance of a type registered to the toy dialect. |
| mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { |
| // Parse a struct type in the following form: |
| // struct-type ::= `struct` `<` type (`,` type)* `>` |
| |
| // NOTE: All MLIR parser function return a ParseResult. This is a |
| // specialization of LogicalResult that auto-converts to a `true` boolean |
| // value on failure to allow for chaining, but may be used with explicit |
| // `mlir::failed/mlir::succeeded` as desired. |
| |
| // Parse: `struct` `<` |
| if (parser.parseKeyword("struct") || parser.parseLess()) |
| return Type(); |
| |
| // Parse the element types of the struct. |
| SmallVector<mlir::Type, 1> elementTypes; |
| do { |
| // Parse the current element type. |
| llvm::SMLoc typeLoc = parser.getCurrentLocation(); |
| mlir::Type elementType; |
| if (parser.parseType(elementType)) |
| return nullptr; |
| |
| // Check that the type is either a TensorType or another StructType. |
| if (!elementType.isa<mlir::TensorType, StructType>()) { |
| parser.emitError(typeLoc, "element type for a struct must either " |
| "be a TensorType or a StructType, got: ") |
| << elementType; |
| return Type(); |
| } |
| elementTypes.push_back(elementType); |
| |
| // Parse the optional: `,` |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| // Parse: `>` |
| if (parser.parseGreater()) |
| return Type(); |
| return StructType::get(elementTypes); |
| } |
| |
| /// Print an instance of a type registered to the toy dialect. |
| void ToyDialect::printType(mlir::Type type, |
| mlir::DialectAsmPrinter &printer) const { |
| // Currently the only toy type is a struct type. |
| StructType structType = type.cast<StructType>(); |
| |
| // Print the struct type according to the parser format. |
| printer << "struct<"; |
| llvm::interleaveComma(structType.getElementTypes(), printer); |
| printer << '>'; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TableGen'd op method definitions |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "toy/Ops.cpp.inc" |