|  | //===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===// | 
|  | // | 
|  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | 
|  |  | 
|  | #include "mlir/Dialect/Affine/IR/AffineOps.h" | 
|  | #include "mlir/Dialect/Linalg/IR/Linalg.h" | 
|  | #include <optional> | 
|  |  | 
|  | using namespace mlir; | 
|  | using namespace mlir::linalg; | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | /// Pattern to decompose a GenericOp that has more than two statements | 
|  | /// into one GenericOp with the first statement (i.e. peeled operation), and | 
|  | /// a second GenericOp with the remaining statements (i.e. residual operations). | 
|  |  | 
|  | /// - The result of the first GenericOp has the same shape as the iteration | 
|  | ///   space of the GenericOp. The body of the op yields as many values as the | 
|  | ///   original op plus all the results of the peeled operation. | 
|  | /// - The second GenericOp has as many operands as the original operation plus | 
|  | /// all the results of the first Generic Op. It has the same number of yields as | 
|  | /// the original op. | 
|  | /// - If the result of the peeled operation was yielded by the original | 
|  | ///   GenericOp the uses of the corresponding results will be replaced with the | 
|  | ///   result of the first GenericOp created. | 
|  | /// | 
|  | ///  Example | 
|  | /// | 
|  | /// ```mlir | 
|  | ///  %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) | 
|  | ///      outs(%init0, %init1 : ...) { | 
|  | ///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...): | 
|  | ///      %0 = <s0> %b0, %b1 : ... | 
|  | ///      %1 = <s1> %0, %b2 : ... | 
|  | ///      linalg.yield %0, %1 : ... | 
|  | ///  } -> (..., ...) | 
|  | ///  return %result#0, %result#1 | 
|  | /// ``` | 
|  | /// | 
|  | /// gets split into | 
|  | /// | 
|  | /// ```mlir | 
|  | /// %init = tensor.empty ... | 
|  | /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) | 
|  | ///      outs(%init0, %init1, %init : ...) | 
|  | ///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): | 
|  | ///      %0 = <s0> %b0, %b1 : ... | 
|  | ///      linalg.yield %0, %..., %0 : ... | 
|  | ///  } -> (..., ..., ...) | 
|  | /// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...) | 
|  | ///      outs(%init0, %init1 : ...) { | 
|  | ///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): | 
|  | ///      %1 = <s1> %b3, %b2 : ... | 
|  | ///      linalg.yield %..., %1 : ... | 
|  | ///  } -> (..., ...) | 
|  | ///  return %op0#0, %op1#1 | 
|  | /// ``` | 
|  | /// | 
|  | /// After canonicalization this is expected to be | 
|  | /// | 
|  | /// ```mlir | 
|  | /// %init = tensor.empty ... | 
|  | /// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...) | 
|  | ///      outs(%init : ...) | 
|  | ///    ^bb0(%b0: ... , %b1: ... , %b2: ...): | 
|  | ///      %0 = <s0> %b0, %b1 : ... | 
|  | ///      linalg.yield %0 : ... | 
|  | ///  } -> ... | 
|  | /// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...) | 
|  | ///      outs(%init1 : ...) { | 
|  | ///    ^bb0(%b0: ... , %b1: ... , %b2: ...): | 
|  | ///      %1 = <s1> %b1, %b0 : ... | 
|  | ///      linalg.yield %..., %1 : ... | 
|  | ///  } -> ... | 
|  | ///  return %op0, %op1 | 
|  | /// ``` | 
|  | struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> { | 
|  | using OpRewritePattern<GenericOp>::OpRewritePattern; | 
|  |  | 
|  | LogicalResult matchAndRewrite(GenericOp genericOp, | 
|  | PatternRewriter &rewriter) const override; | 
|  |  | 
|  | private: | 
|  | /// Helper method to create a generic op for the peeled scalar operation. The | 
|  | /// created op has an empty region. | 
|  | GenericOp createPeeledGenericOp(GenericOp genericOp, | 
|  | PatternRewriter &rewriter) const; | 
|  |  | 
|  | /// Helper method to create a generic op for the residual scalar operation. | 
|  | /// The created op has the same region as the original op. | 
|  | GenericOp createResidualGenericOp(GenericOp genericOp, | 
|  | GenericOp peeledGenericOp, | 
|  | PatternRewriter &rewriter) const; | 
|  | }; | 
|  | } // namespace | 
|  |  | 
|  | /// Helper method to compute the range of a generic op. | 
|  | static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b, | 
|  | GenericOp op) { | 
|  | OpBuilder::InsertionGuard g(b); | 
|  | b.setInsertionPoint(op); | 
|  | Location loc = op.getLoc(); | 
|  | auto allShapesSizes = | 
|  | cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc); | 
|  | AffineMap map = op.getShapesToLoopsMap(); | 
|  | IRRewriter rewriter(b); | 
|  | return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, | 
|  | allShapesSizes); | 
|  | } | 
|  |  | 
|  | /// Helper method to permute the list of `values` based on the `map`. | 
|  | SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values, | 
|  | AffineMap map) { | 
|  | assert(map.isPermutation()); | 
|  | SmallVector<OpFoldResult> permutedValues(values.size()); | 
|  | for (const auto &position : | 
|  | llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) { | 
|  | return cast<AffineDimExpr>(expr).getPosition(); | 
|  | }))) | 
|  | permutedValues[position.value()] = values[position.index()]; | 
|  | return permutedValues; | 
|  | } | 
|  |  | 
|  | /// Get zero value for an element type. | 
|  | static Value getZero(OpBuilder &b, Location loc, Type elementType) { | 
|  | assert(elementType.isIntOrIndexOrFloat() && | 
|  | "expected scalar type while computing zero value"); | 
|  | if (isa<IntegerType>(elementType)) | 
|  | return b.create<arith::ConstantIntOp>(loc, 0, elementType); | 
|  | if (elementType.isIndex()) | 
|  | return b.create<arith::ConstantIndexOp>(loc, 0); | 
|  | // Assume float. | 
|  | auto floatType = cast<FloatType>(elementType); | 
|  | return b.create<arith::ConstantFloatOp>( | 
|  | loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); | 
|  | } | 
|  |  | 
|  | GenericOp | 
|  | DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, | 
|  | PatternRewriter &rewriter) const { | 
|  | Block *body = genericOp.getBody(); | 
|  | Operation *peeledScalarOperation = &(*body->begin()); | 
|  | SmallVector<AffineMap> peeledGenericOpIndexingMaps = | 
|  | genericOp.getIndexingMapsArray(); | 
|  |  | 
|  | /// Compute the loop ranges for operation. This is the shape of the result of | 
|  | /// the generic op for the peeled operation. | 
|  | Location loc = genericOp.getLoc(); | 
|  | SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp); | 
|  | SmallVector<Value> newInitValues; | 
|  | SmallVector<Type> newResultTypes; | 
|  |  | 
|  | // Add as many new results as the number of results of the peeled scalar op. | 
|  | for (auto scalarOpResult : peeledScalarOperation->getResults()) { | 
|  | // If the result is yielded by the original op, use the operand, indexing | 
|  | // map and result type that correspond to the yielded value. | 
|  |  | 
|  | std::optional<unsigned> resultNumber; | 
|  | for (auto *user : scalarOpResult.getUsers()) { | 
|  | if (auto yieldOp = dyn_cast<YieldOp>(user)) { | 
|  | // Find the first use of the `scalarOpResult` in the yield op. | 
|  | for (OpOperand &yieldOperand : yieldOp->getOpOperands()) { | 
|  | if (yieldOperand.get() == scalarOpResult) { | 
|  | resultNumber = yieldOperand.getOperandNumber(); | 
|  | break; | 
|  | } | 
|  | } | 
|  | assert(resultNumber && "unable to find use of a value in its user"); | 
|  | break; | 
|  | } | 
|  | } | 
|  | if (resultNumber) { | 
|  | newInitValues.push_back( | 
|  | genericOp.getDpsInitOperand(*resultNumber)->get()); | 
|  | OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber)); | 
|  | newResultTypes.push_back(result.getType()); | 
|  | peeledGenericOpIndexingMaps.push_back( | 
|  | genericOp.getIndexingMapMatchingResult(result)); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | // Fall back path, use an `init_tensor` and identity indexing map. | 
|  | AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size()); | 
|  | Value emptyTensor = | 
|  | rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType()); | 
|  | newInitValues.push_back(emptyTensor); | 
|  | newResultTypes.push_back(emptyTensor.getType()); | 
|  | peeledGenericOpIndexingMaps.push_back(indexingMap); | 
|  | } | 
|  |  | 
|  | /// Create the peeled generic op with an empty body. | 
|  | SmallVector<Value> outsOperands = genericOp.getOutputs(); | 
|  | outsOperands.append(newInitValues.begin(), newInitValues.end()); | 
|  | SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes()); | 
|  | resultTypes.append(newResultTypes.begin(), newResultTypes.end()); | 
|  | auto indexingMapAttr = | 
|  | rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps); | 
|  | return rewriter.create<GenericOp>( | 
|  | loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr, | 
|  | genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, | 
|  | [](OpBuilder, Location, ValueRange) {}); | 
|  | } | 
|  |  | 
|  | GenericOp | 
|  | DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, | 
|  | GenericOp peeledGenericOp, | 
|  | PatternRewriter &rewriter) const { | 
|  | /// Append all results from the peeledGenericOps as `ins` operand for the | 
|  | /// residual generic op. | 
|  | SmallVector<Value> residualGenericOpOperands = genericOp.getInputs(); | 
|  | unsigned origNumResults = genericOp.getNumResults(); | 
|  | unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults(); | 
|  | SmallVector<Value> extraIns; | 
|  | for (auto resultNum : | 
|  | llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) | 
|  | extraIns.push_back(peeledGenericOp->getResult(resultNum)); | 
|  | residualGenericOpOperands.append(extraIns); | 
|  |  | 
|  | /// Add indexing maps for the newly added operands. Use the same map | 
|  | /// as those used for the new results of the peeledGenericOp. | 
|  | auto indexingMaps = llvm::to_vector( | 
|  | llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) { | 
|  | return genericOp.getMatchingIndexingMap(operand); | 
|  | })); | 
|  | for (auto resultNum : | 
|  | llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) { | 
|  | OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum)); | 
|  | indexingMaps.push_back( | 
|  | peeledGenericOp.getIndexingMapMatchingResult(result)); | 
|  | } | 
|  | for (OpOperand &outOperand : genericOp.getDpsInitsMutable()) | 
|  | indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand)); | 
|  |  | 
|  | auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); | 
|  | return rewriter.create<GenericOp>( | 
|  | genericOp->getLoc(), genericOp->getResultTypes(), | 
|  | residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr, | 
|  | genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, | 
|  | [](OpBuilder, Location, ValueRange) {}); | 
|  | } | 
|  |  | 
|  | LogicalResult | 
|  | DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, | 
|  | PatternRewriter &rewriter) const { | 
|  | /// For now only match on operations where the iterator types are all parallel | 
|  | if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) { | 
|  | return rewriter.notifyMatchFailure(genericOp, | 
|  | "unhandled decomposition of operation " | 
|  | "with non-parallel iterator types"); | 
|  | } | 
|  | // TODO: this could be generalized to handle `linalg.generic` with buffer | 
|  | // operands too but requires allocation for intermediates. Punt on this for | 
|  | // now. | 
|  | if (!genericOp.hasPureTensorSemantics()) { | 
|  | return rewriter.notifyMatchFailure( | 
|  | genericOp, "only operations with tensor semantics are handled"); | 
|  | } | 
|  |  | 
|  | if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) { | 
|  | return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation(); | 
|  | })) { | 
|  | return rewriter.notifyMatchFailure( | 
|  | genericOp, "unhandled decomposition of generic op with out operand not " | 
|  | "accessed using a permutation"); | 
|  | } | 
|  |  | 
|  | /// If the op has only a single statement (apart from the yield), do nothing. | 
|  | Block *body = genericOp.getBody(); | 
|  | if (body->getOperations().size() <= 2) { | 
|  | return rewriter.notifyMatchFailure(genericOp, | 
|  | "operation has less than 3 statements"); | 
|  | } | 
|  |  | 
|  | /// Check that the peeled statement has a scalar element type. | 
|  | if (llvm::any_of(body->getOperations().begin()->getResultTypes(), | 
|  | [](Type t) { return !t.isIntOrIndexOrFloat(); })) { | 
|  | return rewriter.notifyMatchFailure( | 
|  | &(*body->getOperations().begin()), | 
|  | "expected return type to be only int, index or float"); | 
|  | } | 
|  |  | 
|  | GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter); | 
|  | GenericOp residualGenericOp = | 
|  | createResidualGenericOp(genericOp, peeledGenericOp, rewriter); | 
|  |  | 
|  | /// Move the first statement of the original operation into the body of the | 
|  | /// generic op for the peeled operation. | 
|  | Block *peeledGenericOpBody = peeledGenericOp.getBody(); | 
|  | Block *residualGenericOpBody = residualGenericOp.getBody(); | 
|  | assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() && | 
|  | "expected split generic ops to have empty region"); | 
|  | peeledGenericOpBody->getOperations().splice( | 
|  | peeledGenericOpBody->begin(), body->getOperations(), body->begin()); | 
|  | residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(), | 
|  | body->getOperations()); | 
|  |  | 
|  | Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin()); | 
|  | auto *yieldOp = residualGenericOpBody->getTerminator(); | 
|  | { | 
|  | // Yield all the result of the peeled scalar operation. | 
|  | OpBuilder::InsertionGuard g(rewriter); | 
|  | rewriter.setInsertionPointToEnd(peeledGenericOpBody); | 
|  | SmallVector<Value> yieldedVals; | 
|  | for (auto origYield : yieldOp->getOperands()) { | 
|  | if (origYield.getDefiningOp() == peeledScalarOperation) { | 
|  | yieldedVals.push_back(origYield); | 
|  | } else { | 
|  | // Do not materialize any new ops inside of the decomposed LinalgOp, | 
|  | // as that would trigger another application of the rewrite pattern | 
|  | // (infinite loop). | 
|  | OpBuilder::InsertionGuard g(rewriter); | 
|  | rewriter.setInsertionPoint(peeledGenericOp); | 
|  | yieldedVals.push_back( | 
|  | getZero(rewriter, genericOp.getLoc(), origYield.getType())); | 
|  | } | 
|  | } | 
|  | yieldedVals.append(llvm::to_vector( | 
|  | llvm::map_range(peeledScalarOperation->getResults(), | 
|  | [](OpResult opr) -> Value { return opr; }))); | 
|  | rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals); | 
|  | } | 
|  |  | 
|  | /// In the split operations, replace block arguments uses that refer to | 
|  | /// original operation to the block arguments of the newly created operation. | 
|  | unsigned origNumInputs = genericOp.getNumDpsInputs(); | 
|  | for (const auto &inputBlockArg : | 
|  | llvm::enumerate(genericOp.getBody()->getArguments())) { | 
|  | Value residualOpReplacementArg = | 
|  | residualGenericOpBody->getArgument(inputBlockArg.index()); | 
|  | rewriter.replaceUsesWithIf( | 
|  | inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) { | 
|  | return use.getOwner()->getBlock() == residualGenericOpBody; | 
|  | }); | 
|  |  | 
|  | Value peeledOpReplacementArg = | 
|  | peeledGenericOpBody->getArgument(inputBlockArg.index()); | 
|  | rewriter.replaceUsesWithIf( | 
|  | inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) { | 
|  | return use.getOwner()->getBlock() == peeledGenericOpBody; | 
|  | }); | 
|  | } | 
|  |  | 
|  | /// Before fixing up the residual operation, track what values are yielded. If | 
|  | /// any of those are from the peeled scalar operation, the uses of the | 
|  | /// corresponding result have to be remapped to result of the generic op for | 
|  | /// the peeled operation. | 
|  | SmallVector<Value> replacements; | 
|  | for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) { | 
|  | OpResult opr = dyn_cast<OpResult>(yieldValue.value()); | 
|  | if (!opr || opr.getOwner() != peeledScalarOperation) | 
|  | replacements.push_back(residualGenericOp.getResult(yieldValue.index())); | 
|  | else | 
|  | replacements.push_back(peeledGenericOp->getResult(yieldValue.index())); | 
|  | } | 
|  |  | 
|  | /// Update all uses of the peeled scalar operation results in the residual op | 
|  | /// to the newly added arguments. | 
|  | { | 
|  | SmallVector<Value> scalarReplacements; | 
|  | unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults(); | 
|  | scalarReplacements.reserve(peeledScalarOpNumResults); | 
|  | for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults)) | 
|  | scalarReplacements.push_back( | 
|  | residualGenericOpBody->getArgument(num + origNumInputs)); | 
|  | bool allUsesReplaced = false; | 
|  | rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements, | 
|  | residualGenericOpBody, &allUsesReplaced); | 
|  | assert(!allUsesReplaced && | 
|  | "peeled scalar operation is erased when it wasnt expected to be"); | 
|  | } | 
|  |  | 
|  | // Replace the original operation | 
|  | rewriter.replaceOp(genericOp, replacements); | 
|  | return success(); | 
|  | } | 
|  |  | 
|  | void mlir::linalg::populateDecomposeLinalgOpsPattern( | 
|  | RewritePatternSet &patterns, bool removeDeadArgsAndResults) { | 
|  | patterns.insert<DecomposeLinalgOp>(patterns.getContext()); | 
|  | // Add the patterns to clean up the dead operands and results. | 
|  | if (removeDeadArgsAndResults) | 
|  | populateEraseUnusedOperandsAndResultsPatterns(patterns); | 
|  | } |