| //===- FuncOps.cpp - Func Dialect Operations ------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| |
| #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Interfaces/FunctionImplementation.h" |
| #include "mlir/Support/MathExtras.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include <numeric> |
| |
| #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" |
| |
| using namespace mlir; |
| using namespace mlir::func; |
| |
| //===----------------------------------------------------------------------===// |
| // FuncDialect |
| //===----------------------------------------------------------------------===// |
| |
/// Dialect initialization hook: registers every `func` operation and declares
/// the interface implementations this dialect promises for itself and its ops.
void FuncDialect::initialize() {
  // The operation list is produced by TableGen; GET_OP_LIST makes the
  // generated file expand to a comma-separated list of op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  // Dialect-level interfaces declared as promised rather than attached here.
  // NOTE(review): presumably the concrete implementations are registered by
  // separate extensions (inliner / to-LLVM conversion) — confirm at the
  // registration sites.
  declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
  // Op-level promise: CallOp, FuncOp and ReturnOp are declared to implement
  // BufferizableOpInterface; the implementation is not in this file.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
                            FuncOp, ReturnOp>();
}
| |
| /// Materialize a single constant operation from a given attribute value with |
| /// the desired resultant type. |
| Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, |
| Type type, Location loc) { |
| if (ConstantOp::isBuildableWith(value, type)) |
| return builder.create<ConstantOp>(loc, type, |
| llvm::cast<FlatSymbolRefAttr>(value)); |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| // Check that the callee attribute was specified. |
| auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); |
| if (!fnAttr) |
| return emitOpError("requires a 'callee' symbol reference attribute"); |
| FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
| if (!fn) |
| return emitOpError() << "'" << fnAttr.getValue() |
| << "' does not reference a valid function"; |
| |
| // Verify that the operand and result types match the callee. |
| auto fnType = fn.getFunctionType(); |
| if (fnType.getNumInputs() != getNumOperands()) |
| return emitOpError("incorrect number of operands for callee"); |
| |
| for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
| if (getOperand(i).getType() != fnType.getInput(i)) |
| return emitOpError("operand type mismatch: expected operand type ") |
| << fnType.getInput(i) << ", but provided " |
| << getOperand(i).getType() << " for operand number " << i; |
| |
| if (fnType.getNumResults() != getNumResults()) |
| return emitOpError("incorrect number of results for callee"); |
| |
| for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
| if (getResult(i).getType() != fnType.getResult(i)) { |
| auto diag = emitOpError("result type mismatch at index ") << i; |
| diag.attachNote() << " op result types: " << getResultTypes(); |
| diag.attachNote() << "function result types: " << fnType.getResults(); |
| return diag; |
| } |
| |
| return success(); |
| } |
| |
| FunctionType CallOp::getCalleeType() { |
| return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallIndirectOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Fold indirect calls that have a constant function as the callee operand. |
| LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, |
| PatternRewriter &rewriter) { |
| // Check that the callee is a constant callee. |
| SymbolRefAttr calledFn; |
| if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) |
| return failure(); |
| |
| // Replace with a direct call. |
| rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, |
| indirectCall.getResultTypes(), |
| indirectCall.getArgOperands()); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConstantOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConstantOp::verify() { |
| StringRef fnName = getValue(); |
| Type type = getType(); |
| |
| // Try to find the referenced function. |
| auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName); |
| if (!fn) |
| return emitOpError() << "reference to undefined function '" << fnName |
| << "'"; |
| |
| // Check that the referenced function has the correct type. |
| if (fn.getFunctionType() != type) |
| return emitOpError("reference to function with mismatched type"); |
| |
| return success(); |
| } |
| |
| OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { |
| return getValueAttr(); |
| } |
| |
| void ConstantOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "f"); |
| } |
| |
| bool ConstantOp::isBuildableWith(Attribute value, Type type) { |
| return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FuncOp |
| //===----------------------------------------------------------------------===// |
| |
| FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
| ArrayRef<NamedAttribute> attrs) { |
| OpBuilder builder(location->getContext()); |
| OperationState state(location, getOperationName()); |
| FuncOp::build(builder, state, name, type, attrs); |
| return cast<FuncOp>(Operation::create(state)); |
| } |
| FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
| Operation::dialect_attr_range attrs) { |
| SmallVector<NamedAttribute, 8> attrRef(attrs); |
| return create(location, name, type, llvm::ArrayRef(attrRef)); |
| } |
| FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
| ArrayRef<NamedAttribute> attrs, |
| ArrayRef<DictionaryAttr> argAttrs) { |
| FuncOp func = create(location, name, type, attrs); |
| func.setAllArgAttrs(argAttrs); |
| return func; |
| } |
| |
| void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
| FunctionType type, ArrayRef<NamedAttribute> attrs, |
| ArrayRef<DictionaryAttr> argAttrs) { |
| state.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(name)); |
| state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); |
| state.attributes.append(attrs.begin(), attrs.end()); |
| state.addRegion(); |
| |
| if (argAttrs.empty()) |
| return; |
| assert(type.getNumInputs() == argAttrs.size()); |
| function_interface_impl::addArgAndResultAttrs( |
| builder, state, argAttrs, /*resultAttrs=*/std::nullopt, |
| getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); |
| } |
| |
| ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
| auto buildFuncType = |
| [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
| function_interface_impl::VariadicFlag, |
| std::string &) { return builder.getFunctionType(argTypes, results); }; |
| |
| return function_interface_impl::parseFunctionOp( |
| parser, result, /*allowVariadic=*/false, |
| getFunctionTypeAttrName(result.name), buildFuncType, |
| getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
| } |
| |
| void FuncOp::print(OpAsmPrinter &p) { |
| function_interface_impl::printFunctionOp( |
| p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
| getArgAttrsAttrName(), getResAttrsAttrName()); |
| } |
| |
| /// Clone the internal blocks from this function into dest and all attributes |
| /// from this function to dest. |
| void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { |
| // Add the attributes of this function to dest. |
| llvm::MapVector<StringAttr, Attribute> newAttrMap; |
| for (const auto &attr : dest->getAttrs()) |
| newAttrMap.insert({attr.getName(), attr.getValue()}); |
| for (const auto &attr : (*this)->getAttrs()) |
| newAttrMap.insert({attr.getName(), attr.getValue()}); |
| |
| auto newAttrs = llvm::to_vector(llvm::map_range( |
| newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { |
| return NamedAttribute(attrPair.first, attrPair.second); |
| })); |
| dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); |
| |
| // Clone the body. |
| getBody().cloneInto(&dest.getBody(), mapper); |
| } |
| |
| /// Create a deep copy of this function and all of its blocks, remapping |
| /// any operands that use values outside of the function using the map that is |
| /// provided (leaving them alone if no entry is present). Replaces references |
| /// to cloned sub-values with the corresponding value that is copied, and adds |
| /// those mappings to the mapper. |
| FuncOp FuncOp::clone(IRMapping &mapper) { |
| // Create the new function. |
| FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); |
| |
| // If the function has a body, then the user might be deleting arguments to |
| // the function by specifying them in the mapper. If so, we don't add the |
| // argument to the input type vector. |
| if (!isExternal()) { |
| FunctionType oldType = getFunctionType(); |
| |
| unsigned oldNumArgs = oldType.getNumInputs(); |
| SmallVector<Type, 4> newInputs; |
| newInputs.reserve(oldNumArgs); |
| for (unsigned i = 0; i != oldNumArgs; ++i) |
| if (!mapper.contains(getArgument(i))) |
| newInputs.push_back(oldType.getInput(i)); |
| |
| /// If any of the arguments were dropped, update the type and drop any |
| /// necessary argument attributes. |
| if (newInputs.size() != oldNumArgs) { |
| newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, |
| oldType.getResults())); |
| |
| if (ArrayAttr argAttrs = getAllArgAttrs()) { |
| SmallVector<Attribute> newArgAttrs; |
| newArgAttrs.reserve(newInputs.size()); |
| for (unsigned i = 0; i != oldNumArgs; ++i) |
| if (!mapper.contains(getArgument(i))) |
| newArgAttrs.push_back(argAttrs[i]); |
| newFunc.setAllArgAttrs(newArgAttrs); |
| } |
| } |
| } |
| |
| /// Clone the current function into the new one and return it. |
| cloneInto(newFunc, mapper); |
| return newFunc; |
| } |
| FuncOp FuncOp::clone() { |
| IRMapping mapper; |
| return clone(mapper); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReturnOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReturnOp::verify() { |
| auto function = cast<FuncOp>((*this)->getParentOp()); |
| |
| // The operand number and types must match the function signature. |
| const auto &results = function.getFunctionType().getResults(); |
| if (getNumOperands() != results.size()) |
| return emitOpError("has ") |
| << getNumOperands() << " operands, but enclosing function (@" |
| << function.getName() << ") returns " << results.size(); |
| |
| for (unsigned i = 0, e = results.size(); i != e; ++i) |
| if (getOperand(i).getType() != results[i]) |
| return emitError() << "type of return operand " << i << " (" |
| << getOperand(i).getType() |
| << ") doesn't match function result type (" |
| << results[i] << ")" |
| << " in function @" << function.getName(); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TableGen'd op method definitions |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" |