| //===- Sparsification.cpp - Implementation of sparsification --------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements converting sparse tensor types to actual sparse code. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Utils/CodegenEnv.h" |
| #include "Utils/CodegenUtils.h" |
| #include "Utils/LoopEmitter.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/AffineExprVisitor.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/TensorEncoding.h" |
| #include "llvm/ADT/SmallBitVector.h" |
| |
| #include <optional> |
| |
| using namespace mlir; |
| using namespace mlir::sparse_tensor; |
| |
| //===----------------------------------------------------------------------===// |
| // Sparsifier analysis methods. |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns true iff the affine expression is invariant. Sets the output |
| /// parameter `isCurrentLoop` when the expression just becomes invariant at |
| /// the current loop. |
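| /// For example (a small illustration), with curr == 2 the expression d0 + d1 |
| /// is invariant: both d0 and d1 have already been generated, and d1 is the |
| /// most recent loop, so `isCurrentLoop` is set as well. |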
| static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) { |
| switch (a.getKind()) { |
| case AffineExprKind::DimId: { |
| const LoopId i = cast<AffineDimExpr>(a).getPosition(); |
| if (i + 1 == curr) { |
| isCurrentLoop = true; |
| return true; // becomes invariant at current loop |
| } |
| return i < curr; // invariant when already generated |
| } |
| case AffineExprKind::Add: |
| case AffineExprKind::Mul: { |
| auto binOp = cast<AffineBinaryOpExpr>(a); |
| return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) && |
| isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop); |
| } |
| default: { |
| assert(isa<AffineConstantExpr>(a)); |
| return true; |
| } |
| } |
| } |
| |
| /// Helper method to inspect affine expressions. Rejects cases where the |
| /// same index is used more than once. Also rejects compound affine |
| /// expressions in sparse dimensions. |
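| /// For example (illustrative), A[i][i] is rejected because index i is used |
| /// more than once, while A[i+j] on a dense level is accepted without |
| /// recording a level format for either i or j. |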
| static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, |
| LevelType lt, bool setLvlFormat = true) { |
| switch (a.getKind()) { |
| case AffineExprKind::DimId: { |
| const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
| if (!isUndefLT(merger.getLvlType(tid, idx))) |
| return false; // used more than once |
| if (setLvlFormat) |
| merger.setLevelAndType(tid, idx, lvl, lt); |
| return true; |
| } |
| case AffineExprKind::Add: |
| case AffineExprKind::Mul: |
| case AffineExprKind::Constant: { |
| assert(lt.hasDenseSemantic()); |
| if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) { |
| // We do not set the dim level format for an affine expression like d0 + d1 |
| // on either loop index d0 or d1. We continue the recursion merely to check |
| // whether the current affine expression is admissible. |
| return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) && |
| findAffine(merger, tid, lvl, binOp.getRHS(), lt, false); |
| } |
| // Falls through when it is a constant affine expression. |
| return true; |
| } |
| default: |
| return false; |
| } |
| } |
| |
| /// Helper method to inspect affine expressions for index variable reduction |
| /// based codegen. It finds the dependent index set for all tensor levels in the |
| /// current expression we are generating. |
| /// |
| /// For example, when handling A[i+j][j+k], we build the two way mapping in |
| /// merger between (tensor, level) pairs and their dependent index variable set: |
| /// A_0 <=> [i, j] and A_1 <=> [j, k] |
| /// |
| /// It rejects (returns false) when: |
| /// 1. the same index is used more than once, e.g., A[i+j][i]; |
| /// 2. a multiplication is used in the non-trivial index expression; |
| /// 3. a constant operand is used in the non-trivial index expression. |
| /// |
| /// TODO: constant should be easy to handle. |
| static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, |
| AffineExpr a, LevelType lt, bool isSubExp = false, |
| int64_t coefficient = 1) { |
| switch (a.getKind()) { |
| case AffineExprKind::DimId: { |
| // Only allow positive coefficients on AffineDimExpr. |
| if (coefficient <= 0) |
| return false; |
| |
| const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
| if (!isUndefLT(merger.getLvlType(tensor, idx))) |
| return false; // used more than once, e.g., A[i][i] |
| |
| // TODO: Generalize the following two cases. A[i] (with trivial index |
| // expression) can be treated as a special affine index expression. We do |
| // not necessarily need to differentiate them. |
| if (!isSubExp) { |
| assert(coefficient == 1); |
| merger.setLevelAndType(tensor, idx, lvl, lt); |
| } |
| |
| if (isSubExp) { |
| // The current loop appears in more than one affine expression on the |
| // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i` |
| // is used twice. |
| if (merger.hasDependentLvl(idx, tensor)) { |
| // TODO: This can be supported by coiterating slices if the loop idx |
| // appears in affine indices of different tensors, or by taking a slice |
| // on multiple dimensions when it is on the same tensor. |
| // E.g., |
| // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0] |
| // d0_1 = getNextSliceOffset t0 along lvl0 |
| // d0_2 = getNextSliceOffset t1 along lvl0 |
| // if d0_1 == d0_2 then d0 = d0_1 = d0_2 |
| // else increase min(d0_1, d0_2). |
| return false; |
| } |
| merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient); |
| } |
| return true; |
| } |
| case AffineExprKind::Constant: |
| case AffineExprKind::Mul: { |
| // TODO: Support index expressions like `2 * d0`; we currently only support |
| // more complicated cases like `2 * d0 + d1`. |
| if (!isSubExp) |
| return false; |
| |
| // TODO: Support Constant AffineExp for slice-based codegen |
| if (isa<AffineConstantExpr>(a)) |
| llvm_unreachable("Not yet implemented"); |
| |
| auto binOp = cast<AffineBinaryOpExpr>(a); |
| auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); |
| if (isa<AffineConstantExpr>(rhs)) |
| std::swap(lhs, rhs); |
| // Must be in form of `constant * d`. |
| assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs)); |
| int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue(); |
| return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient); |
| } |
| case AffineExprKind::Add: { |
| auto binOp = cast<AffineBinaryOpExpr>(a); |
| return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) && |
| findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true); |
| } |
| default: |
| return false; |
| } |
| } |
| |
| /// Gets the total number of compound affine expressions on sparse levels in |
| /// the `getMatchingIndexingMap` for the given tensor. For the following |
| /// inputs: |
| /// |
| /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed) |
| /// |
| /// Returns 1 (because the first level is compressed and its corresponding |
| /// indexing-expression is `d0 + d1`) |
| static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, |
| Value tensor) { |
| // The `tensor` is not guaranteed to have `RankedTensorType`, therefore |
| // we can't use `getRankedTensorType`/`getSparseTensorType` here. |
| // However, we don't need to handle `StorageSpecifierType`, so we |
| // can use `SparseTensorType` once we guard against non-tensors. |
| const auto rtp = dyn_cast<RankedTensorType>(tensor.getType()); |
| if (!rtp) |
| return 0; |
| const SparseTensorType stt(rtp); |
| |
| const Level lvlRank = stt.getLvlRank(); |
| const auto exprs = map.getResults(); |
| assert(static_cast<Level>(exprs.size()) == lvlRank && |
| "AffineMap does not have level-rank many results"); |
| unsigned num = 0; |
| for (Level l = 0; l < lvlRank; l++) { |
| if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic()) |
| num++; |
| } |
| return num; |
| } |
| |
| /// Gets the total number of sparse levels with compound affine |
| /// expressions, summed over all operands of the `GenericOp`. |
| static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) { |
| unsigned num = 0; |
| for (OpOperand &t : op->getOpOperands()) |
| num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t), |
| t.get()); |
| return num; |
| } |
| |
| /// Returns true iff the output has non-trivial affine indices. |
| static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) { |
| OpOperand *out = op.getDpsInitOperand(0); |
| if (getSparseTensorType(out->get()).isAllDense()) |
| return false; |
| return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out), |
| out->get()); |
| } |
| |
| /// Helper method to inspect sparse encodings in the tensor types. |
| /// Fills the per-dimension sparsity information for all tensors. |
| /// Returns true if the sparse annotations and affine subscript |
| /// expressions of all tensors are admissible. Returns false if |
| /// no annotations are found or inadmissible constructs occur. |
| /// We currently support two different ways to handle non-trivial index |
| /// expressions on sparse tensors, and they accept different affine expressions. |
| /// The dependent index reduction-based approach currently only supports |
| /// affine addition index expressions. |
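| /// For example (illustrative), for a CSR-annotated matrix A accessed as |
| /// A(i, j) with level types (dense, compressed), this records loop i for the |
| /// dense level 0 and loop j for the compressed level 1 in the merger. |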
| static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) { |
| bool annotated = false; |
| for (OpOperand &t : env.op()->getOpOperands()) { |
| const TensorId tid = env.makeTensorId(t.getOperandNumber()); |
| const auto map = env.op().getMatchingIndexingMap(&t); |
| const auto enc = getSparseTensorEncoding(t.get().getType()); |
| if (enc) |
| annotated = true; |
| const Level lvlRank = map.getNumResults(); |
| assert(!enc || lvlRank == enc.getLvlRank()); |
| assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank); |
| // We only need to do index reduction if there is at least one |
| // non-trivial index expression on sparse levels. If all non-trivial |
| // index expressions are on dense levels, we can efficiently rely on |
| // random access to locate the elements. |
| bool needIdxReduc = |
| enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0; |
| // If the current tensor being inspected requires an affine index, it |
| // needs to be sliced. |
| for (Level l = 0; l < lvlRank; l++) { |
| const AffineExpr a = map.getResult(l); |
| const LevelType lt = enc.getLvlType(l); |
| if (idxReducBased && needIdxReduc) { |
| if (!findDepIdxSet(env.merger(), tid, l, a, lt)) |
| return false; // inadmissible affine expression |
| } else { |
| if (!findAffine(env.merger(), tid, l, a, lt)) |
| return false; // inadmissible affine expression |
| } |
| } |
| } |
| return annotated; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Sparsifier synthesis methods (statements and expressions). |
| //===----------------------------------------------------------------------===// |
| |
| /// Local bufferization of all dense and sparse data structures. |
| static void genBuffers(CodegenEnv &env, OpBuilder &builder) { |
| linalg::GenericOp op = env.op(); |
| Location loc = op.getLoc(); |
| assert(op.getNumOperands() == op.getNumDpsInputs() + 1); |
| |
| SmallVector<Range, 4> loopRange = |
| llvm::cast<linalg::LinalgOp>(op.getOperation()) |
| .createLoopRanges(builder, loc); |
| |
| env.emitter().initializeLoopEmit( |
| builder, loc, |
| /// Generates buffer for the output tensor. |
| /// Note that all sparse kernels assume that when all elements are written |
| /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized |
| /// to all zeroes and only nonzero values are computed and written out. |
| /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used |
| /// for the updates and no assumption on the original contents of the |
| /// output buffer is necessary. |
| [&op](OpBuilder &builder, Location loc, Value memref, |
| Value tensor) -> Value { |
| // Must not be a sparse tensor. |
| assert(!getSparseTensorEncoding(tensor.getType())); |
| // Two output tensor references should point to the same object. |
| OpOperand *lhs = op.getDpsInitOperand(0); |
| assert(lhs->get() == tensor); |
| // An output tensor can simply materialize from the buffer of the tensor |
| // that appears in the outs() clause. For updates, this has the |
| // advantage that only the nonzero values are involved in the |
| // computation, keeping the operation O(nnz). In all other cases, we are |
| // forced to zero out the buffer to enforce the assumption above, which |
| // may negatively impact running complexity (viz. O(n^2 + nnz) vs. |
| // O(nnz) for matrices). |
| // TODO: use better analysis to avoid zeroing out the buffer? |
| bool isInit = op.isInitTensor(lhs); |
| Value init = memref; |
| if (!isInit) { |
| Value zero = constantZero(builder, loc, |
| getElementTypeOrSelf(tensor.getType())); |
| builder.create<linalg::FillOp>(loc, ValueRange{zero}, |
| ValueRange{init}); |
| } |
| return init; |
| }, |
| [&loopRange](OpBuilder &b, Location loc, Level l) { |
| assert(l < loopRange.size()); |
| return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size); |
| }); |
| } |
| |
| /// Generates index for load/store on sparse tensor. |
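| /// As a small illustration, for a map such as (d0, d1) -> (d0, d1) this |
| /// returns the loop variable bound to d1, since the last level is required |
| /// to be a trivial dimension expression. |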
| static Value genIndex(CodegenEnv &env, OpOperand *t) { |
| const auto map = env.op().getMatchingIndexingMap(t); |
| const auto stt = getSparseTensorType(t->get()); |
| const Level lvlRank = stt.getLvlRank(); |
| assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| const AffineExpr a = map.getResult(lvlRank - 1); |
| assert(a.getKind() == AffineExprKind::DimId); |
| const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition()); |
| return env.getLoopVar(idx); |
| } |
| |
| /// Generates subscript for load/store on a dense or sparse tensor. |
| static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, |
| SmallVectorImpl<Value> &args) { |
| const Location loc = env.op().getLoc(); |
| const TensorId tid = env.makeTensorId(t->getOperandNumber()); |
| const auto map = env.op().getMatchingIndexingMap(t); |
| const auto stt = getSparseTensorType(t->get()); |
| if (stt.hasEncoding()) { |
| // For sparse tensors we only push the last-level's position onto `args`. |
| const auto pos = env.emitter().getValPosits(tid); |
| assert(!pos.empty()); |
| args.append(pos); |
| } else { |
| // For dense tensors we push all level's coordinates onto `args`. |
| const Level lvlRank = stt.getLvlRank(); |
| assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| for (Level l = 0; l < lvlRank; l++) { |
| const auto lvlExpr = map.getResult(l); |
| const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr); |
| args.push_back(lvlCrd); |
| } |
| } |
| return env.emitter().getValBuffer()[tid]; |
| } |
| |
| /// Generates insertion code to implement dynamic tensor load. |
| static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, |
| OpOperand *t) { |
| linalg::GenericOp op = env.op(); |
| Location loc = op.getLoc(); |
| // Direct lexicographic coordinate order, tensor loads as zero. |
| if (!env.isExpand()) { |
| Type tp = getElementTypeOrSelf(t->get().getType()); |
| return constantZero(builder, loc, tp); |
| } |
| // Load from expanded access pattern. |
| Value index = genIndex(env, t); |
| return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index); |
| } |
| |
| /// Generates insertion code to implement dynamic tensor load for reduction. |
| static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, |
| OpOperand *t) { |
| linalg::GenericOp op = env.op(); |
| Location loc = op.getLoc(); |
| Value identity = env.getCustomRedId(); |
| // Direct lexicographic coordinate order, tensor loads as identity. |
| if (!env.isExpand()) |
| return identity; |
| // Load from expanded access pattern if filled, identity otherwise. |
| Value values = env.getExpandValues(); |
| Value filled = env.getExpandFilled(); |
| Value index = genIndex(env, t); |
| Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); |
| Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index); |
| return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity); |
| } |
| |
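| /// Generates a conditional insertion of `v` into `sparseOut` at `ivs`, |
| /// yielding either the updated or the original insertion chain. Roughly |
| /// (an illustrative sketch of the code built below): |
| /// if (cond) then |
| /// yield insert(v) into sparseOut at ivs |
| /// else |
| /// yield unmodified sparseOut |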
| static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, |
| Value sparseOut, ValueRange ivs, Value v) { |
| scf::IfOp condInsert = |
| builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true); |
| // True branch. |
| builder.setInsertionPointToStart(condInsert.thenBlock()); |
| Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs); |
| builder.create<scf::YieldOp>(loc, res); |
| // False branch. |
| builder.setInsertionPointToStart(condInsert.elseBlock()); |
| builder.create<scf::YieldOp>(loc, sparseOut); |
| // Value assignment. |
| builder.setInsertionPointAfter(condInsert); |
| return condInsert.getResult(0); |
| } |
| |
| /// Generates insertion code to implement dynamic tensor store. |
| static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, |
| Value rhs) { |
| linalg::GenericOp op = env.op(); |
| Location loc = op.getLoc(); |
| // Direct insertion in lexicographic coordinate order. |
| if (!env.isExpand()) { |
| const LoopId numLoops = op.getRank(t); |
| // Retrieves the first `numLoops` induction variables. |
| SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end( |
| env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops)); |
| Value chain = env.getInsertionChain(); |
| if (env.isValidLexInsert()) { |
| // Generates runtime check for a valid lex during reduction, |
| // to avoid inserting the identity value for empty reductions. |
| // if (validLexInsert) then |
| // insert(rhs) into chain |
| // return updated chain |
| // else |
| // return unmodified chain |
| Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(), |
| chain, ivs, rhs); |
| env.updateInsertionChain(out); |
| } else { |
| Value sparseOut; |
| if (!hasAnySparseType(env.op().getInputs().getTypes())) { |
| // This is an all-dense -> sparse kernel, test rhs != 0 before |
| // insertion. |
| Value nz = genIsNonzero(builder, loc, rhs); |
| sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs); |
| } else { |
| sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs); |
| } |
| // Generates regular insertion chain. |
| env.updateInsertionChain(sparseOut); |
| } |
| return; |
| } |
| // Generates insertion code along expanded access pattern. |
| // if (!expFilled[i]) then |
| // expFilled[i] = true |
| // expAdded[inserts++] = i |
| // endif |
| // values[i] = rhs |
| Value values = env.getExpandValues(); |
| Value filled = env.getExpandFilled(); |
| Value added = env.getExpandAdded(); |
| Value count = env.getExpandCount(); |
| Value index = genIndex(env, t); |
| Value fval = constantI1(builder, loc, false); |
| Value tval = constantI1(builder, loc, true); |
| // If statement. |
| Value isFilled = builder.create<memref::LoadOp>(loc, filled, index); |
| Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| isFilled, fval); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond, |
| /*else=*/true); |
| // True branch. |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| builder.create<memref::StoreOp>(loc, tval, filled, index); |
| builder.create<memref::StoreOp>(loc, index, added, count); |
| Value one = constantIndex(builder, loc, 1); |
| Value add = builder.create<arith::AddIOp>(loc, count, one); |
| builder.create<scf::YieldOp>(loc, add); |
| // False branch. |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| builder.create<scf::YieldOp>(loc, count); |
| builder.setInsertionPointAfter(ifOp); |
| // Value assignment. |
| env.updateExpandCount(ifOp.getResult(0)); |
| builder.create<memref::StoreOp>(loc, rhs, values, index); |
| } |
| |
| /// Generates a load on a dense or sparse tensor. |
| static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { |
| // Test if the load was hoisted to a higher loop nest. |
| Value val = env.exp(exp).val; |
| if (val) |
| return val; |
| // Load during insertion. |
| linalg::GenericOp op = env.op(); |
| OpOperand *t = &op->getOpOperand(env.exp(exp).tensor); |
| if (env.isSparseOutput(t)) { |
| if (env.isCustomReduc()) |
| return genInsertionLoadReduce(env, builder, t); |
| return genInsertionLoad(env, builder, t); |
| } |
| // Actual load. |
| SmallVector<Value> args; |
| Value ptr = genSubscript(env, builder, t, args); |
| return builder.create<memref::LoadOp>(op.getLoc(), ptr, args); |
| } |
| |
| /// Generates a store on a dense or sparse tensor. |
| static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
| Value rhs) { |
| // Only unary and binary operations are allowed to return an uninitialized |
| // rhs to indicate a missing output, or otherwise a custom reduction that |
| // received no value to accumulate. |
| if (!rhs) { |
| assert(env.exp(exp).kind == TensorExp::Kind::kUnary || |
| env.exp(exp).kind == TensorExp::Kind::kBinary || |
| env.exp(exp).kind == TensorExp::Kind::kReduce); |
| return; |
| } |
| // Test if this is a scalarized reduction. |
| if (env.isReduc()) { |
| env.updateReduc(rhs); |
| return; |
| } |
| // Regular store. |
| linalg::GenericOp op = env.op(); |
| Location loc = op.getLoc(); |
| OpOperand *t = op.getDpsInitOperand(0); |
| if (!env.isSparseOutput(t)) { |
| SmallVector<Value> args; |
| Value ptr = genSubscript(env, builder, t, args); |
| builder.create<memref::StoreOp>(loc, rhs, ptr, args); |
| return; |
| } |
| // Store during sparse insertion. |
| if (env.exp(exp).kind != TensorExp::Kind::kSelect) { |
| genInsertionStore(env, builder, t, rhs); |
| return; |
| } |
| // Select operation insertion. |
| Value chain = env.getInsertionChain(); |
| scf::IfOp ifOp = |
| builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true); |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| // Existing value was preserved to be used here. |
| assert(env.exp(exp).val); |
| Value v0 = env.exp(exp).val; |
| genInsertionStore(env, builder, t, v0); |
| env.merger().clearExprValue(exp); |
| // Yield modified insertion chain along true branch. |
| Value mchain = env.getInsertionChain(); |
| builder.create<scf::YieldOp>(op.getLoc(), mchain); |
| // Yield original insertion chain along false branch. |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| builder.create<scf::YieldOp>(loc, chain); |
| // Done with if statement. |
| env.updateInsertionChain(ifOp->getResult(0)); |
| builder.setInsertionPointAfter(ifOp); |
| } |
| |
| /// Generates an invariant value. |
| inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) { |
| return env.exp(exp).val; |
| } |
| |
| /// Semi-ring branches are simply inlined by the sparsifier. Prior |
| /// analysis has verified that all computations are "local" to the inlined |
| /// branch or otherwise invariantly defined outside the loop nest, with the |
| /// exception of index computations, which need to be relinked to actual |
| /// inlined cloned code. |
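| /// For example (illustrative), a `linalg.index 0` operation cloned into the |
| /// inlined branch is replaced by the loop variable of loop 0, and a block |
| /// argument of the original linalg op is turned into a dense tensor load. |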
| static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, |
| Value e) { |
| if (auto arg = dyn_cast<BlockArgument>(e)) { |
| // Direct arguments of the original linalg op must be converted |
| // into dense tensor loads. Note that we should not encounter |
| // anything else. This needs to be verified by semi-ring ops. |
| linalg::GenericOp op = env.op(); |
| if (arg.getOwner()->getParentOp() == op) { |
| const TensorId tid = env.makeTensorId(arg.getArgNumber()); |
| OpOperand *t = &op->getOpOperand(tid); |
| assert(!getSparseTensorType(t->get()).hasEncoding()); // dense! |
| SmallVector<Value> args; |
| Value ptr = genSubscript(env, rewriter, t, args); |
| return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args); |
| } |
| } else if (Operation *def = e.getDefiningOp()) { |
| // Handle index computation. |
| if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) |
| return env.getLoopVar(env.makeLoopId(indexOp.getDim())); |
| // When still defined in new body, recurse into operands. |
| if (def->getBlock() == block) { |
| rewriter.setInsertionPoint(def); |
| for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { |
| rewriter.modifyOpInPlace(def, [&]() { |
| def->setOperand( |
| i, relinkBranch(env, rewriter, block, def->getOperand(i))); |
| }); |
| } |
| } |
| } |
| return e; |
| } |
| |
| /// Recursively generates tensor expression. |
| static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) { |
| if (e == ::mlir::sparse_tensor::detail::kInvalidId) |
| return Value(); |
| |
| linalg::GenericOp op = env.op(); |
| Location loc = op.getLoc(); |
| const TensorExp &exp = env.exp(e); |
| const auto kind = exp.kind; |
| if (kind == TensorExp::Kind::kTensor) |
| return genTensorLoad(env, rewriter, e); |
| if (kind == TensorExp::Kind::kInvariant) |
| return genInvariantValue(env, e); |
| if (kind == TensorExp::Kind::kLoopVar) |
| return env.getLoopVar(exp.loop); |
| |
| if (kind == TensorExp::Kind::kReduce) |
| env.startCustomReduc(e); // enter custom |
| |
| // If either lhs/rhs is a synthetic zero, we infer the type for the zero value |
| // based on the type of the other operand. |
| Value v0, v1; |
| if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId && |
| env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) { |
| v1 = genExp(env, rewriter, exp.children.e1); |
| v0 = constantZero(rewriter, loc, v1.getType()); |
| } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId && |
| env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) { |
| v0 = genExp(env, rewriter, exp.children.e0); |
| v1 = constantZero(rewriter, loc, v0.getType()); |
| } else { |
| v0 = genExp(env, rewriter, exp.children.e0); |
| v1 = genExp(env, rewriter, exp.children.e1); |
| } |
| |
| Value ee; |
| if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) { |
| // custom reduce did not receive a value |
| } else { |
| ee = env.merger().buildExp(rewriter, loc, e, v0, v1); |
| if (ee && |
| (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary || |
| kind == TensorExp::Kind::kBinaryBranch || |
| kind == TensorExp::Kind::kReduce || |
| kind == TensorExp::Kind::kSelect)) { |
| OpBuilder::InsertionGuard guard(rewriter); |
| ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee); |
| } |
| } |
| |
| if (kind == TensorExp::Kind::kReduce) |
| env.endCustomReduc(); // exit custom |
| |
| if (kind == TensorExp::Kind::kSelect) |
| env.merger().setExprValue(e, v0); // Preserve value for later use. |
| |
| return ee; |
| } |
| |
| /// Hoists loop invariant tensor loads for which indices have been exhausted. |
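| /// For example (illustrative), in x(i) = SUM_j a(i, j) * b(j), the access |
| /// x(i) no longer depends on loop j and is therefore scalarized into a |
| /// reduction around the j-loop, while an input access whose indices are |
| /// exhausted is hoisted as a loop-invariant load. |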
| static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
| LoopId curr, bool isStart) { |
| if (exp == ::mlir::sparse_tensor::detail::kInvalidId) |
| return; |
| if (env.exp(exp).kind == TensorExp::Kind::kTensor) { |
| // Inspect tensor indices. |
| linalg::GenericOp op = env.op(); |
| OpOperand &t = op->getOpOperand(env.exp(exp).tensor); |
| const auto map = op.getMatchingIndexingMap(&t); |
| const auto stt = getSparseTensorType(t.get()); |
| const Level lvlRank = stt.getLvlRank(); |
| assert(static_cast<Level>(map.getNumResults()) == lvlRank); |
| bool isCurrentLoop = curr == 0; // for scalar tensors |
| for (Level l = 0; l < lvlRank; l++) { |
| const AffineExpr a = map.getResult(l); |
| if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop)) |
| return; // still in play |
| } |
| // All exhausted at current level. |
| if (!isCurrentLoop) |
| return; |
| // Generate code for a scalarized reduction or invariant. Note that |
| // because a custom reduction lhs may occur several times in the IR, |
| // there is a built-in safeguard that initializes and wraps up the |
| // scalarized reduction only once. |
| OpOperand *lhs = op.getDpsInitOperand(0); |
| if (lhs == &t) { |
| // Start or end a scalarized reduction. |
| if (isStart) { |
| if (env.isCustomReduc()) { |
| if (!env.isReduc()) |
| env.startReduc(exp, env.getCustomRedId()); |
| } else { |
| env.startReduc(exp, genTensorLoad(env, builder, exp)); |
| } |
| if (env.hasSparseOutput()) |
| env.startValidLexInsert( |
| constantI1(builder, env.op().getLoc(), false)); |
| } else { |
| if (!env.isCustomReduc() || env.isReduc()) |
| genTensorStore(env, builder, exp, env.endReduc()); |
| if (env.hasSparseOutput()) |
| env.endValidLexInsert(); |
| } |
| } else { |
| // Start or end loop invariant hoisting of a tensor load. |
| if (isStart) { |
| env.merger().setExprValue(exp, genTensorLoad(env, builder, exp)); |
| } else { |
| env.merger().clearExprValue(exp); |
| } |
| } |
| } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant && |
| env.exp(exp).kind != TensorExp::Kind::kLoopVar && |
| env.exp(exp).kind != TensorExp::Kind::kSynZero) { |
| // Traverse into the binary operations. Note that we only hoist |
| // tensor loads, since subsequent MLIR/LLVM passes know how to |
| // deal with all other kinds of derived loop invariants. |
| if (env.exp(exp).kind == TensorExp::Kind::kReduce) |
| env.startCustomReduc(exp); // enter custom |
| const ExprId e0 = env.exp(exp).children.e0; |
| const ExprId e1 = env.exp(exp).children.e1; |
| genInvariants(env, builder, e0, curr, isStart); |
| genInvariants(env, builder, e1, curr, isStart); |
| if (env.exp(exp).kind == TensorExp::Kind::kReduce) |
| env.endCustomReduc(); // exit custom |
| } |
| } |
| |
| /// Generates an expanded access pattern in innermost dimension. |
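| /// At the start this emits a sparse_tensor.expand that yields scratch |
| /// buffers (values, filled, added) and a count; at the end it emits a |
| /// sparse_tensor.compress that folds the gathered entries back into the |
| /// insertion chain (a brief sketch of the code built below). |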
| static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| bool isStart) { |
| linalg::GenericOp op = env.op(); |
| OpOperand *lhs = op.getDpsInitOperand(0); |
| if (!env.atExpandLevel(lhs, op.getRank(lhs), curr)) |
| return; // not needed at current level |
| assert(!env.isReduc()); |
| // Generate start or end of an expanded access pattern. Note that because |
| // an expansion does not rely on the ongoing contents of the sparse storage |
| // scheme, we can use the original tensor as incoming SSA value (which |
| // simplifies codegen a bit). If expansion on the actual contents is ever |
| // needed, we will need to use the SSA value in the insertion chain instead. |
| Value tensor = lhs->get(); |
| Location loc = op.getLoc(); |
| if (isStart) { |
| auto dynShape = {ShapedType::kDynamic}; |
| Type etp = cast<ShapedType>(tensor.getType()).getElementType(); |
| Type t1 = MemRefType::get(dynShape, etp); |
| Type t2 = MemRefType::get(dynShape, builder.getI1Type()); |
| Type t3 = MemRefType::get(dynShape, builder.getIndexType()); |
| Type t4 = builder.getIndexType(); |
| auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); |
| assert(r.getNumResults() == 4); |
| env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2), |
| r.getResult(3)); |
| } else { |
| SmallVector<Value> indices; |
| for (LoopId i = 0; i < curr; i++) |
| indices.push_back(env.emitter().getLoopIV(i)); |
| Value values = env.getExpandValues(); |
| Value filled = env.getExpandFilled(); |
| Value added = env.getExpandAdded(); |
| Value count = env.getExpandCount(); |
| Value chain = env.getInsertionChain(); |
| Value compress = builder.create<CompressOp>(loc, values, filled, added, |
| count, chain, indices); |
| env.updateInsertionChain(compress); |
| env.endExpand(); |
| } |
| } |
| |
| /// Returns true if the loop should be turned into a parallel loop. Any |
| /// implicit loop in the Linalg operation that is marked "parallel" is a |
| /// candidate. Whether it is actually converted to a parallel operation |
| /// depends on the requested strategy. |
| static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) { |
| // Reject parallelization of sparse output. |
| if (env.hasSparseOutput()) |
| return false; |
| // Parallel loops on tensor expansion can cause data races. |
| if (env.isExpand()) |
| return false; |
| // Inspect strategy. |
| switch (env.options().parallelizationStrategy) { |
| case SparseParallelizationStrategy::kNone: |
| return false; |
| case SparseParallelizationStrategy::kDenseOuterLoop: |
| return isOuter && !isSparse; |
| case SparseParallelizationStrategy::kAnyStorageOuterLoop: |
| return isOuter; |
| case SparseParallelizationStrategy::kDenseAnyLoop: |
| return !isSparse; |
| case SparseParallelizationStrategy::kAnyStorageAnyLoop: |
| return true; |
| } |
| llvm_unreachable("unexpected parallelization strategy"); |
| } |
| |
| /// Whether or not the current loop being generated should be parallelized |
| /// (if possible) according to the configuration. |
| static bool shouldTryParallize(CodegenEnv &env, LoopId curr, |
| ArrayRef<TensorLevel> tidLvls) { |
| linalg::GenericOp op = env.op(); |
| auto iteratorTypes = op.getIteratorTypesArray(); |
| bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) { |
| // Queries the LT based on the tensor and loop id, as requested by |
| // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv |
| // should be consistent with the LT indexed by <TensorId, Level>. |
| const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr); |
| return lt.hasSparseSemantic(); |
| }); |
| return isParallelFor(env, /*isOuter=*/curr == 0, isSparse); |
| } |
| |
| /// Emits a loop to coiterate over the list of tensor levels. The generated |
| /// loop can be either a for-loop or a while-loop, depending on whether there |
| /// is at most one sparse level in the list. |
| static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder, |
| ArrayRef<TensorLevel> tidLvls, |
| bool tryParallel, bool needsUniv) { |
| Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { |
| // Construct while-loop with a parameter for each index. |
| return env.emitter().enterCoIterationOverTensorsAtLvls( |
| builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv); |
| }); |
| assert(loop); |
| return loop; |
| } |
| |
| /// Generates a for-loop or a while-loop, depending on whether it implements |
| /// singleton iteration or co-iteration over the given conjunction. |
| static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| bool needsUniv, ArrayRef<TensorLevel> tidLvls) { |
| bool tryParallel = shouldTryParallize(env, curr, tidLvls); |
| return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv); |
| } |
| |
| /// Generates the induction structure for a while-loop. |
| static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, |
| bool needsUniv) { |
| Location loc = env.op().getLoc(); |
| // Finalize each else branch of all if statements. |
| if (env.isReduc() || env.isExpand() || env.getInsertionChain()) { |
| while (auto ifOp = dyn_cast_or_null<scf::IfOp>( |
| builder.getInsertionBlock()->getParentOp())) { |
| // Break on IfOp for slicing filtering. |
| if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) == |
| StringAttr::get(ifOp->getContext(), "slice")) |
| break; |
| |
| unsigned y = 0; |
| SmallVector<Value> yields; |
| if (env.isReduc()) { |
| yields.push_back(env.getReduc()); |
| env.updateReduc(ifOp.getResult(y++)); |
| if (env.isValidLexInsert()) { |
| yields.push_back(env.getValidLexInsert()); |
| env.updateValidLexInsert(ifOp.getResult(y++)); |
| } |
| } |
| if (env.isExpand()) { |
| yields.push_back(env.getExpandCount()); |
| env.updateExpandCount(ifOp->getResult(y++)); |
| } |
| if (env.getInsertionChain()) { |
| yields.push_back(env.getInsertionChain()); |
| env.updateInsertionChain(ifOp->getResult(y++)); |
| } |
| assert(y == yields.size()); |
| builder.create<scf::YieldOp>(loc, yields); |
| builder.setInsertionPointAfter(ifOp); |
| } |
| } |
| // No need to set the insertion point here as LoopEmitter keeps track of the |
| // basic block where scf::Yield should be inserted. |
| } |
| |
| /// Generates a single if-statement within a while-loop. |
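| /// As an illustration, when coiterating two sparse tensors a and b over |
| /// loop i, the generated condition is roughly |
| /// if (a_crd == i && b_crd == i) then ... else ... |
| /// where dense (or undefined) levels simply contribute a true clause. |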
| static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, |
| LatPointId p) { |
| Location loc = env.op().getLoc(); |
| SmallVector<Type> types; |
| Value cond; |
| env.merger().foreachTensorLoopId( |
| p, /*simple=*/true, |
| [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt, |
| bool isIdxRed) { |
| if (isIdxRed) { |
| // Since there is no 1:1 mapping from loop to level (multiple loops |
| // are required to resolve one level with a non-trivial index |
| // expression), we need to reconstruct the tensor level types if this |
| // loop requires an index reduction condition. |
| assert(lvl.has_value() && isUndefLT(lt)); |
| auto stt = getSparseTensorType(env.op().getInputs()[tid]); |
| lt = stt.getLvlType(*lvl); |
| } |
| assert(curr == env.merger().loop(b)); |
| Value clause; |
| if (lt.hasSparseSemantic()) { |
| assert(lvl.has_value()); |
| const Value crd = env.emitter().getCoord(tid, *lvl); |
| const Value lvar = env.getLoopVar(curr); |
| clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| crd, lvar); |
| } else { |
| assert(lt.hasDenseSemantic() || isUndefLT(lt)); |
| clause = constantI1(builder, loc, true); |
| } |
| cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause; |
| }); |
| if (env.isReduc()) { |
| types.push_back(env.getReduc().getType()); |
| if (env.isValidLexInsert()) |
| types.push_back(env.getValidLexInsert().getType()); |
| } |
| if (env.isExpand()) |
| types.push_back(builder.getIndexType()); |
| if (env.getInsertionChain()) |
| types.push_back(env.getInsertionChain().getType()); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| return ifOp; |
| } |
| |
| /// Generates end of true branch of if-statement within a while-loop. |
| static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, |
| Value redInput, Value cntInput, Value insInput, |
| Value validIns) { |
| SmallVector<Value> operands; |
| if (env.isReduc()) { |
| operands.push_back(env.getReduc()); |
| env.updateReduc(redInput); |
| if (env.isValidLexInsert()) { |
| // Any overlap of indices during a reduction creates a valid lex insert. |
| operands.push_back(constantI1(builder, env.op().getLoc(), true)); |
| env.updateValidLexInsert(validIns); |
| } |
| } |
| if (env.isExpand()) { |
| operands.push_back(env.getExpandCount()); |
| env.updateExpandCount(cntInput); |
| } |
| if (env.getInsertionChain()) { |
| operands.push_back(env.getInsertionChain()); |
| env.updateInsertionChain(insInput); |
| } |
| if (!operands.empty()) |
| builder.create<scf::YieldOp>(env.op().getLoc(), operands); |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Sparsifier synthesis methods (loop sequence). |
| //===----------------------------------------------------------------------===// |
| |
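| /// Collects, through the callback, all tensor levels that the loop at `curr` |
| /// needs to visit for the lattice point `li`, together with an affine |
| /// expression whenever the address of a dense level can be computed at this |
| /// loop. Returns true when a single loop condition over unique levels |
| /// suffices, so that the loop can be emitted as a for-loop rather than a |
| /// while-loop (a descriptive summary of the logic below). |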
| static bool getAllTidLvlsInLatPoints( |
| CodegenEnv &env, LatPointId li, LoopId curr, |
| llvm::function_ref<void(TensorLevel, AffineExpr)> callback) { |
| const BitVector &simple = env.lat(li).simple; |
| const TensorId outTid = env.merger().getOutTensorID(); |
| const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr); |
| |
| unsigned numloopCond = 0; |
| bool hasNonUnique = false; |
| env.merger().foreachTensorLoopId( |
| li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl, |
| LevelType lt, bool isIdxReduc) { |
| if (simple[b]) { |
| if (isIdxReduc) { |
| callback(env.makeTensorLevel(tid, *lvl), nullptr); |
| numloopCond++; |
| return; |
| } |
| if (isUndefLT(lt)) { |
| // An undefined lt in the lattices means we probably want to |
| // generate a dense loop according to the synthetic tensor (for |
| // invariants and the sparse output tensor). |
| if (env.merger().getSynTensorID() == tid) { |
| // Coiterating with an invariant |
| // e.g., out = prod(in[i][j] op invariant); |
| // or a broadcast |
| // e.g., out[i][j] = in[i] (j is undef for input) |
| // |
| // The level of the synthetic tensor is the current loop depth; |
| // the rank of the synthetic tensor equals the number of loops. |
| assert(curr == env.getCurrentDepth()); |
| lvl = curr; |
| } else if (!lvl) { |
| // Skips invalid lvl (e.g., when this is a zero ranked tensor). |
| return; |
| } |
| } |
| hasNonUnique = !isUniqueLT(lt) || hasNonUnique; |
| callback(env.makeTensorLevel(tid, *lvl), nullptr); |
| numloopCond++; |
| } else if (lt.hasDenseSemantic() || isIdxReduc) { |
| callback(env.makeTensorLevel(tid, *lvl), nullptr); |
| } else { |
| assert(isUndefLT(lt)); |
| linalg::GenericOp op = env.op(); |
| if (tid >= op.getNumDpsInputs()) |
| // We only handle affine expression on input tensors (for now). |
| return; |
| OpOperand *operand = &op->getOpOperand(tid); |
| const auto stt = getSparseTensorType(operand->get()); |
| // Non-annotated dense tensors require no special handling. |
| if (!stt.hasEncoding()) |
| return; |
| |
| ArrayRef<AffineExpr> affines = |
| op.getMatchingIndexingMap(operand).getResults(); |
| const Level lvlRank = stt.getLvlRank(); |
| assert(affines.size() == static_cast<size_t>(lvlRank)); |
| for (Level l = 0; l < lvlRank; l++) { |
| AffineExpr exp = affines[l]; |
| // Skip simple affine expressions and non-dense levels (which |
| // have their own filter loop). |
| LevelType lt = stt.getLvlType(l); |
| if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic()) |
| continue; |
| |
| // Constant affine expressions are handled in genLoop. |
| if (!isa<AffineConstantExpr>(exp)) { |
| bool isCurrentLoop = false; |
| assert(curr == env.getCurrentDepth()); |
| if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) && |
| isCurrentLoop) { |
| // If the compound affine expression is invariant and we are right |
| // at this level, we need to generate the address according to the |
| // affine expression. This is also the best place to do so, since it |
| // avoids putting the address computation inside inner loops. |
| callback(env.makeTensorLevel(tid, l), exp); |
| } |
| } |
| } |
| } |
| }); |
| |
| if (isDenseLT(env.lt(outTid, curr))) { |
| auto stt = getSparseTensorType(env.op().getOutputs().front()); |
| // Note that we generate dense indices of the output tensor |
| // unconditionally, since they may not appear in the lattice but may be |
| // needed for the linearized environment. |
| // TODO: we should avoid introducing corner cases for all-dense sparse |
| // tensors. |
| if (stt.hasEncoding() && stt.isAllDense()) |
| callback(env.makeTensorLevel(outTid, *outLvl), nullptr); |
| } |
| |
| if (numloopCond == 0) { |
| // Corner case where the loop bound is defined by an *unused* operand; in |
| // this case, we just generate a dense "fake" loop by iterating over the |
| // synthetic tensor. |
| callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr); |
| numloopCond++; |
| } |
| // If we need just one loop condition and the condition is not imposed on a |
| // non-unique level, the loop can be generated by a for-loop. |
| return numloopCond == 1 && !hasNonUnique; |
| } |
| |
| /// Starts a loop sequence at given level. Returns true if |
| /// the universal loop index must be maintained at this level. |
| static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, |
| LoopId curr, LatSetId lts) { |
| assert(!env.getLoopVar(curr)); |
| // Emit invariants at this loop sequence level. |
| genInvariants(env, builder, exp, curr, /*isStart=*/true); |
| // Emit access pattern expansion for sparse tensor output. |
| genExpand(env, builder, curr, /*isStart=*/true); |
| // Emit further initialization at this loop sequence level. |
| const LatPointId l0 = env.set(lts)[0]; |
| |
| SmallVector<TensorLevel> tidLvls; |
| getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) { |
| // TODO: remove this! The same tensor level might be added multiple |
| // times due to the special handling of the all-dense "sparse" output |
| // tensor (see L1038). |
| if (llvm::find(tidLvls, tl) != tidLvls.end()) |
| return; |
| tidLvls.emplace_back(tl); |
| }); |
| |
| env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls); |
| |
| // Maintain the universal index only if it is actually |
| // consumed by a subsequent lattice point. |
| for (const LatPointId li : env.set(lts).drop_front()) |
| if (!env.merger().hasAnySparse(env.lat(li).simple)) |
| return true; |
| |
| return false; |
| } |
| |
| // Generates dense affine address for encoding. |
| static void genConstantDenseAddressFromLevel(CodegenEnv &env, |
| OpBuilder &builder, TensorId tid, |
| Level startLvl) { |
| // TODO: Handle affine expression on output tensor. |
| linalg::GenericOp op = env.op(); |
| assert(tid < op.getNumDpsInputs()); |
| OpOperand *input = op.getDpsInputOperands()[tid]; |
| const auto lvlExprs = op.getMatchingIndexingMap(input).getResults(); |
| const auto enc = getSparseTensorEncoding(input->get().getType()); |
| if (enc) { |
| const Location loc = op.getLoc(); |
| const TensorId tid = env.makeTensorId(input->getOperandNumber()); |
| const Level lvlRank = enc.getLvlRank(); |
| assert(lvlExprs.size() == static_cast<size_t>(lvlRank)); |
| for (Level l = startLvl; l < lvlRank; l++) { |
| AffineExpr lvlExpr = lvlExprs[l]; |
| if (enc.getLvlType(l).hasDenseSemantic() && |
| isa<AffineConstantExpr>(lvlExpr)) |
| env.emitter().locateLvlAtAffineAddress( |
| builder, loc, env.makeTensorLevel(tid, l), lvlExpr); |
| else |
| return; // break on first non-dense non-constant level |
| } |
| } |
| } |
| |
| // We can generate addresses for constant affine expressions before any |
| // loops, starting from the first level, as they do not depend on anything. |
| // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two |
| // levels can be determined before loops. |
| static void genInitConstantDenseAddress(CodegenEnv &env, |
| RewriterBase &rewriter) { |
| for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++) |
| genConstantDenseAddressFromLevel(env, rewriter, tid, 0); |
| } |
| |
| /// Returns true if the lattice bit can be iterated by a for loop. |
| static bool translateBitsToTidLvlPairs( |
| CodegenEnv &env, LatPointId li, LoopId curr, |
| SmallVectorImpl<TensorLevel> &tidLvls, |
| SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) { |
| return getAllTidLvlsInLatPoints(env, li, curr, |
| [&](TensorLevel tl, AffineExpr exp) { |
| if (exp) |
| affineTidLvls.emplace_back(tl, exp); |
| else |
| tidLvls.emplace_back(tl); |
| }); |
| } |
| |
| /// Starts a single loop in current sequence. |
| static std::pair<Operation *, bool> startLoop(CodegenEnv &env, |
| OpBuilder &builder, LoopId curr, |
| LatPointId li, bool needsUniv) { |
| // The set of tensors + lvls to generate loops on |
| SmallVector<TensorLevel> tidLvls; |
| |
| // The set of dense tensors with non-trivial affine expressions that just |
| // became invariant; their addresses are generated at the current level. |
| SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls; |
| bool isSingleCond = |
| translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls); |
| |
| // Emit the for/while-loop control. |
| Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls); |
| Location loc = env.op().getLoc(); |
| for (auto [tidLvl, exp] : affineTidLvls) { |
| env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp); |
| } |
| |
| // Until now, we have entered every <tid, lvl> pair in {cond, extra, |
| // affine}Tids/Lvls. The addresses of the upcoming levels which depend on |
| // constant affine expressions may now be determined. |
| auto allTidLvls = |
| llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls)); |
| for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) { |
| if (tid != env.merger().getOutTensorID() && |
| tid != env.merger().getSynTensorID()) |
| genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); |
| } |
| |
| return std::make_pair(loop, isSingleCond); |
| } |
| |
| /// Ends a single loop in the current sequence. Returns the new value for needsUniv. |
| static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, |
| LatPointId li, bool needsUniv, bool isSingleCond) { |
| // Either a for-loop or a while-loop that iterates over a slice. |
| if (isSingleCond) { |
| // Any iteration creates a valid lex insert. |
| if (env.isReduc() && env.isValidLexInsert()) |
| env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true)); |
| } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { |
| // End a while-loop. |
| finalizeWhileOp(env, rewriter, needsUniv); |
| } else { |
| needsUniv = false; |
| } |
| env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { |
| env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc); |
| return std::nullopt; |
| }); |
| return needsUniv; |
| } |
| |
| /// Ends a loop sequence at given level. |
| static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, |
| unsigned at) { |
| assert(!env.getLoopVar(at)); |
| env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc()); |
| // Unmark bookkeeping of invariants and loop index. |
| genInvariants(env, builder, exp, at, /*isStart=*/false); |
| // Finalize access pattern expansion for sparse tensor output. |
| genExpand(env, builder, at, /*isStart=*/false); |
| } |
| |
| /// Recursively generates code while computing iteration lattices in order |
| /// to manage the complexity of implementing co-iteration over unions |
| /// and intersections of sparse iteration spaces. |
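| /// For example (an illustrative sketch), for x(i) = a(i) + b(i) with both |
| /// inputs sparse, the lattice for loop i has three points (a and b, only a, |
| /// only b), which yields a while-loop that coiterates a and b with |
| /// if-statements for the three cases, followed by for-loops over the |
| /// remaining entries of a and of b. |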
| static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, |
| LoopId curr) { |
| assert(curr == env.getCurrentDepth()); |
| |
| // At each leaf, assign remaining tensor (sub)expression to output tensor. |
| if (curr == env.getLoopNum()) { |
| Value rhs = genExp(env, rewriter, exp); |
| genTensorStore(env, rewriter, exp, rhs); |
| return; |
| } |
| |
| // Construct iteration lattices for current loop index. |
| const LatSetId lts = |
| env.merger().optimizeSet(env.merger().buildLattices(exp, curr)); |
| |
| // Start a loop sequence. |
| bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts); |
| |
| // Emit a loop for every lattice point L0 >= Li in this loop sequence. |
| // We cannot change this to `for (const LatPointId li : env.set(lts))` |
| // because the loop body causes data-movement which invalidates |
| // the iterator. |
| const unsigned lsize = env.set(lts).size(); |
| for (unsigned i = 0; i < lsize; i++) { |
| const LatPointId li = env.set(lts)[i]; |
| // Start a loop. |
| auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv); |
| |
| // Visit all lattices points with Li >= Lj to generate the |
| // loop-body, possibly with if statements for coiteration. |
| Value redInput = env.getReduc(); |
| Value cntInput = env.getExpandCount(); |
| Value insInput = env.getInsertionChain(); |
| Value validIns = env.getValidLexInsert(); |
| // We cannot change this to `for (const LatPointId lj : env.set(lts))` |
| // because the loop body causes data-movement which invalidates the |
| // iterator. |
| for (unsigned j = 0; j < lsize; j++) { |
| const LatPointId lj = env.set(lts)[j]; |
| const ExprId ej = env.lat(lj).exp; |
| if (li == lj || env.merger().latGT(li, lj)) { |
| // Recurse into body of each branch. |
| if (!isSingleCond) { |
| scf::IfOp ifOp = genIf(env, rewriter, curr, lj); |
| genStmt(env, rewriter, ej, curr + 1); |
| endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); |
| } else { |
| genStmt(env, rewriter, ej, curr + 1); |
| } |
| } |
| } |
| |
| // End a loop. |
| needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond); |
| } |
| |
| // End a loop sequence. |
| endLoopSeq(env, rewriter, exp, curr); |
| assert(curr == env.getCurrentDepth()); |
| } |
| |
| /// Converts the result computed by the sparse kernel into the required form. |
| static void genResult(CodegenEnv &env, RewriterBase &rewriter) { |
| linalg::GenericOp op = env.op(); |
| OpOperand *lhs = op.getDpsInitOperand(0); |
| Value tensor = lhs->get(); |
| Type resType = tensor.getType(); |
| if (getSparseTensorEncoding(resType)) { |
| // The sparse tensor rematerializes from the original sparse tensor's |
| // underlying sparse storage format. For an insertion chain, the |
| // tensor materializes from the chain with 'hasInserts' enabled. |
| bool hasInserts = false; |
| if (Value chain = env.getInsertionChain()) { |
| hasInserts = true; |
| tensor = chain; |
| } |
| rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts); |
| } else { |
| // To rematerialize a non-annotated tensor, simply load it |
| // from the bufferized value. |
| Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()]; |
| rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Sparsifier rewriting methods. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Sparse rewriting rule for the generic Linalg operation. |
| struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { |
| public: |
| GenericOpSparsifier(MLIRContext *context, SparsificationOptions o) |
| : OpRewritePattern<linalg::GenericOp>(context), options(o) {} |
| |
| LogicalResult matchAndRewrite(linalg::GenericOp op, |
| PatternRewriter &rewriter) const override { |
| // Only accept single output operations with pure tensor semantics. |
| if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics()) |
| return failure(); |
| |
| // Only accept trivial affine indices. |
| if (hasNonTrivialAffineOnSparseOut(op)) |
| return failure(); |
| |
| // Only accept scheduled loops. |
| if (!op->hasAttr("sorted")) { |
| return rewriter.notifyMatchFailure( |
| op, "Loops not yet scheduled, try run --sparse-reinterpret-map " |
| "before sparsification."); |
| } |
| |
| // Must have been demapped as well if the generic op is sorted. |
| assert(!hasAnyNonIdentityOperandsOrResults(op)); |
| |
| // Sets up a code generation environment. |
| const unsigned numTensors = op->getNumOperands(); |
| const unsigned numLoops = op.getNumLoops(); |
| bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0; |
| // If we have an indexing map like (d0) -> (0, d0), there might be more |
| // levels than loops because of the constant index, which means we cannot |
| // use numLoops as the upper bound for the ranks of all tensors. |
| // TODO: Constant indices are currently not supported on sparse tensors, |
| // but are allowed in non-annotated dense tensors. Supporting them would |
| // also be required for rank-reducing sparse tensor slices. |
| Level maxLvlRank = 0; |
| for (auto operand : op.getOperands()) { |
| if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) { |
| maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank()); |
| } |
| } |
| |
| // Detects sparse annotations and translates the per-level sparsity |
| // information for all tensors to loop indices in the kernel. |
| CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank); |
| if (!findSparseAnnotations(env, needIdxRed)) |
| return failure(); |
| |
| // Only standard reduction operations (add, sub, or, xor) that can be |
| // sparsified by merely reducing the stored values are admissible. More |
| // elaborate reduction operations (such as mul, and, min, max) would need |
| // to know whether implicit zeros occur as well. They can still be |
| // implemented with a custom reduction operation, accepted here as well. |
| if (op.getNumReductionLoops() > 0) { |
| Operation *yield = op.getRegion().front().getTerminator(); |
| assert(isa<linalg::YieldOp>(yield)); |
| Operation *redop = yield->getOperand(0).getDefiningOp(); |
| if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) && |
| !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) && |
| !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) && |
| !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) && |
| !isa<ReduceOp>(redop)) { |
| return failure(); |
| } |
| } |
| |
| // Constructs the tensor expressions tree from `op`, returns failure if the |
| // tree can not be built or the tensor expression is inadmissible. |
| if (failed(env.initTensorExp())) |
| return failure(); |
| |
| // Recursively generates code if admissible. |
| env.startEmit(options.sparseEmitStrategy); |
| genBuffers(env, rewriter); |
| // TODO: Constant affine expressions should be handled differently when |
| // using slice-based codegen; it does not matter now because we already |
| // reject constant expressions at an earlier stage. |
| genInitConstantDenseAddress(env, rewriter); |
| genStmt(env, rewriter, env.getExprId(), 0); |
| genResult(env, rewriter); |
| return success(); |
| } |
| |
| private: |
| /// Options to control sparse code generation. |
| SparsificationOptions options; |
| }; |
| |
| } // namespace |
| |
| /// Populates the given patterns list with rewriting rules required for |
| /// the sparsification of linear algebra operations. |
| void mlir::populateSparsificationPatterns( |
| RewritePatternSet &patterns, const SparsificationOptions &options) { |
| patterns.add<GenericOpSparsifier>(patterns.getContext(), options); |
| } |