| //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Linalg/Passes.h" |
| |
| #include "PassDetail.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| using namespace mlir; |
| |
| static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { |
| if (!op->hasTrait<OpTrait::ElementwiseMappable>()) |
| return false; |
| |
| // TODO: The conversion pattern can be made to work for `any_of` here, but |
| // it's more complex as it requires tracking which operands are scalars. |
| return llvm::all_of(op->getOperandTypes(), |
| [](Type type) { return type.isa<RankedTensorType>(); }); |
| } |
| |
| namespace { |
| struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { |
| ConvertAnyElementwiseMappableOpOnRankedTensors() |
| : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const final { |
| if (!isElementwiseMappableOpOnRankedTensors(op)) |
| return rewriter.notifyMatchFailure( |
| op, "requires elementwise op on ranked tensors"); |
| |
| auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank(); |
| SmallVector<AffineMap, 3> indexingMaps( |
| op->getNumResults() + op->getNumOperands(), |
| rewriter.getMultiDimIdentityMap(rank)); |
| SmallVector<StringRef, 6> iteratorTypes(rank, |
| getParallelIteratorTypeName()); |
| rewriter.replaceOpWithNewOp<linalg::GenericOp>( |
| op, /*resultTensorTypes=*/op->getResultTypes(), |
| /*inputs=*/op->getOperands(), |
| /*outputBuffers=*/ValueRange(), |
| /*initTensors=*/ValueRange(), |
| /*indexingMaps=*/indexingMaps, |
| /*iteratorTypes=*/iteratorTypes, |
| /*bodyBuilder=*/ |
| [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { |
| OperationState state(loc, op->getName()); |
| state.addAttributes(op->getAttrs()); |
| state.addOperands(regionArgs); |
| auto resultTypes = llvm::to_vector<6>( |
| llvm::map_range(op->getResultTypes(), [](Type type) { |
| return type.cast<TensorType>().getElementType(); |
| })); |
| state.addTypes(resultTypes); |
| auto *scalarOp = builder.createOperation(state); |
| builder.create<linalg::YieldOp>(loc, scalarOp->getResults()); |
| }); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::populateElementwiseToLinalgConversionPatterns( |
| OwningRewritePatternList &patterns, MLIRContext *) { |
| patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>(); |
| } |
| |
| namespace { |
| class ConvertElementwiseToLinalgPass |
| : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> { |
| |
| void runOnFunction() final { |
| auto func = getOperation(); |
| auto *context = &getContext(); |
| ConversionTarget target(*context); |
| OwningRewritePatternList patterns; |
| |
| populateElementwiseToLinalgConversionPatterns(patterns, context); |
| target.markUnknownOpDynamicallyLegal([](Operation *op) { |
| return !isElementwiseMappableOpOnRankedTensors(op); |
| }); |
| |
| if (failed(applyPartialConversion(func, target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| std::unique_ptr<OperationPass<FuncOp>> |
| mlir::createConvertElementwiseToLinalgPass() { |
| return std::make_unique<ConvertElementwiseToLinalgPass>(); |
| } |