| //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements rewriting rules that are specific to sparse tensor |
| // primitives with memref operands. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Utils/CodegenUtils.h" |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| #include "mlir/Support/LLVM.h" |
| |
| using namespace mlir; |
| using namespace mlir::sparse_tensor; |
| |
| //===---------------------------------------------------------------------===// |
| // Helper methods for the actual rewriting rules. |
| //===---------------------------------------------------------------------===// |
| |
// Positions of the leading operands shared by the generated helper functions.
// The helpers in this file read their block arguments through these indices.

/// Operand position of the first index operand (the lower bound `lo` for most
/// helpers).
static constexpr uint64_t loIdx = 0;
/// Operand position of the second index operand (the upper bound `hi` for most
/// helpers).
static constexpr uint64_t hiIdx = 1;
/// Operand position of the first data buffer; any further buffers follow it.
static constexpr uint64_t xStartIdx = 2;

// Name prefixes of the generated helper functions. NOTE(review): presumably
// each is passed as `namePrefix` when mangling the corresponding helper's
// name; the call sites are not visible in this part of the file.
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
| |
/// Signature of the callbacks that fill in the body of a freshly created
/// helper function. The callback is invoked as
///   createFunc(builder, module, func, xPerm, ny, nTrailingP)
/// where `func` is the new (still empty) function, `xPerm` and `ny` describe
/// the x and y values being sorted, and `nTrailingP` is the number of
/// trailing parameters appended after the (lo, hi, xs, ys) operands.
using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
                                            AffineMap, uint64_t, uint32_t)>;
| |
| /// Constructs a function name with this format to facilitate quick sort: |
| /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort |
| /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo |
| static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, |
| StringRef namePrefix, AffineMap xPerm, |
| uint64_t ny, ValueRange operands) { |
| nameOstream << namePrefix; |
| for (auto res : xPerm.getResults()) |
| nameOstream << cast<AffineDimExpr>(res).getPosition() << "_"; |
| |
| nameOstream << getMemRefType(operands[xStartIdx]).getElementType(); |
| nameOstream << "_coo_" << ny; |
| |
| constexpr uint64_t yBufferOffset = 1; |
| for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) |
| nameOstream << "_" << getMemRefType(v).getElementType(); |
| } |
| |
| /// Looks up a function that is appropriate for the given operands being |
| /// sorted, and creates such a function if it doesn't exist yet. The |
| /// parameters `xPerm` and `ny` tell the number of x and y values provided |
| /// by the buffer in xStartIdx. |
| // |
| // All sorting function generators take (lo, hi, xs, ys) in `operands` as |
| // parameters for the sorting functions. Other parameters, such as the recursive |
| // call depth, are appended to the end of the parameter list as |
| // "trailing parameters". |
| static FlatSymbolRefAttr getMangledSortHelperFunc( |
| OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, |
| StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, |
| FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { |
| SmallString<32> nameBuffer; |
| llvm::raw_svector_ostream nameOstream(nameBuffer); |
| getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, |
| operands.drop_back(nTrailingP)); |
| |
| ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); |
| MLIRContext *context = module.getContext(); |
| auto result = SymbolRefAttr::get(context, nameOstream.str()); |
| auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); |
| |
| if (!func) { |
| // Create the function. |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| builder.setInsertionPoint(insertPoint); |
| Location loc = insertPoint.getLoc(); |
| func = builder.create<func::FuncOp>( |
| loc, nameOstream.str(), |
| FunctionType::get(context, operands.getTypes(), resultTypes)); |
| func.setPrivate(); |
| createFunc(builder, module, func, xPerm, ny, nTrailingP); |
| } |
| |
| return result; |
| } |
| |
| /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. |
| /// The code to process the value pairs is generated by `bodyBuilder`. |
| static void forEachIJPairInXs( |
| OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
| uint64_t ny, |
| function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { |
| Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny); |
| Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); |
| Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); |
| for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { |
| unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition(); |
| Value ak = constantIndex(builder, loc, actualK); |
| Value i = builder.create<arith::AddIOp>(loc, ak, iOffset); |
| Value j = builder.create<arith::AddIOp>(loc, ak, jOffset); |
| Value buffer = args[xStartIdx]; |
| |
| bodyBuilder(k, i, j, buffer); |
| } |
| } |
| |
| /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. |
| /// The code to process the value pairs is generated by `bodyBuilder`. |
| static void forEachIJPairInAllBuffers( |
| OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
| uint64_t ny, |
| function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { |
| |
| // Create code for the first (xPerm + ny) buffers. |
| SmallVector<AffineExpr> exps(xPerm.getResults().begin(), |
| xPerm.getResults().end()); |
| for (unsigned y = 0; y < ny; y++) { |
| exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults())); |
| } |
| AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext()); |
| assert(xyPerm.isPermutation()); |
| |
| forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder); |
| |
| constexpr uint64_t numHandledBuffers = 1; |
| // Create code for the remaining buffers. |
| Value i = args[0]; |
| Value j = args[1]; |
| for (const auto &arg : |
| llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { |
| bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value()); |
| } |
| } |
| |
| /// Creates a code block for swapping the values in index i and j for all the |
| /// buffers. |
| // |
| // The generated IR corresponds to this C like algorithm: |
| // swap(x0[i], x0[j]); |
| // swap(x1[i], x1[j]); |
| // ... |
| // swap(xn[i], xn[j]); |
| // swap(y0[i], y0[j]); |
| // ... |
| // swap(yn[i], yn[j]); |
| static void createSwap(OpBuilder &builder, Location loc, ValueRange args, |
| AffineMap xPerm, uint64_t ny) { |
| auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { |
| Value vi = builder.create<memref::LoadOp>(loc, buffer, i); |
| Value vj = builder.create<memref::LoadOp>(loc, buffer, j); |
| builder.create<memref::StoreOp>(loc, vj, buffer, i); |
| builder.create<memref::StoreOp>(loc, vi, buffer, j); |
| }; |
| |
| forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair); |
| } |
| |
| /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare |
| /// each pair is create via `compareBuilder`. |
| static Value createInlinedCompareImplementation( |
| OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, |
| uint64_t ny, |
| function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> |
| compareBuilder) { |
| Value result; |
| auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { |
| bool isFirstDim = (k == 0); |
| bool isLastDim = (k == xPerm.getNumResults() - 1); |
| Value val = |
| compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); |
| if (isFirstDim) { |
| result = val; |
| } else if (!isLastDim) { |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| auto ifOp = cast<scf::IfOp>(val.getDefiningOp()); |
| builder.setInsertionPointAfter(ifOp); |
| builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); |
| } |
| }; |
| |
| forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder); |
| |
| builder.setInsertionPointAfterValue(result); |
| return result; |
| } |
| |
| /// Generates code to compare whether x[i] is equal to x[j] and returns the |
| /// result of the comparison. |
| static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, |
| Value x, bool isFirstDim, bool isLastDim) { |
| Value vi = builder.create<memref::LoadOp>(loc, x, i); |
| Value vj = builder.create<memref::LoadOp>(loc, x, j); |
| |
| Value res; |
| if (isLastDim) { |
| res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj); |
| // For 1D, we create a compare without any control flow. Otherwise, we |
| // create YieldOp to return the result in the nested if-stmt. |
| if (!isFirstDim) |
| builder.create<scf::YieldOp>(loc, res); |
| } else { |
| Value ne = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), |
| ne, /*else=*/true); |
| // If (x[i] != x[j]). |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| Value f = constantI1(builder, loc, false); |
| builder.create<scf::YieldOp>(loc, f); |
| |
| // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that |
| // checks the remaining dimensions. |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| res = ifOp.getResult(0); |
| } |
| |
| return res; |
| } |
| |
| /// Creates code to compare whether xs[i] is equal to xs[j]. |
| // |
| // The generate IR corresponds to this C like algorithm: |
| // if (x0[i] != x0[j]) |
| // return false; |
| // else |
| // if (x1[i] != x1[j]) |
| // return false; |
| // else if (x2[2] != x2[j])) |
| // and so on ... |
| static Value createInlinedEqCompare(OpBuilder &builder, Location loc, |
| ValueRange args, AffineMap xPerm, |
| uint64_t ny, uint32_t nTrailingP = 0) { |
| // Compare functions don't use trailing parameters. |
| (void)nTrailingP; |
| assert(nTrailingP == 0); |
| return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, |
| createEqCompare); |
| } |
| |
| /// Generates code to compare whether x[i] is less than x[j] and returns the |
| /// result of the comparison. |
| static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, |
| Value j, Value x, bool isFirstDim, |
| bool isLastDim) { |
| Value vi = builder.create<memref::LoadOp>(loc, x, i); |
| Value vj = builder.create<memref::LoadOp>(loc, x, j); |
| |
| Value res; |
| if (isLastDim) { |
| res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); |
| // For 1D, we create a compare without any control flow. Otherwise, we |
| // create YieldOp to return the result in the nested if-stmt. |
| if (!isFirstDim) |
| builder.create<scf::YieldOp>(loc, res); |
| } else { |
| Value ne = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), |
| ne, /*else=*/true); |
| // If (x[i] != x[j]). |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| Value lt = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); |
| builder.create<scf::YieldOp>(loc, lt); |
| |
| // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that |
| // checks the remaining dimensions. |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| res = ifOp.getResult(0); |
| } |
| |
| return res; |
| } |
| |
| /// Creates code to compare whether xs[i] is less than xs[j]. |
| // |
| // The generate IR corresponds to this C like algorithm: |
| // if (x0[i] != x0[j]) |
| // return x0[i] < x0[j]; |
| // else if (x1[j] != x1[i]) |
| // return x1[i] < x1[j]; |
| // else |
| // and so on ... |
| static Value createInlinedLessThan(OpBuilder &builder, Location loc, |
| ValueRange args, AffineMap xPerm, |
| uint64_t ny, uint32_t nTrailingP = 0) { |
| // Compare functions don't use trailing parameters. |
| (void)nTrailingP; |
| assert(nTrailingP == 0); |
| return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, |
| createLessThanCompare); |
| } |
| |
| /// Creates a function to use a binary search to find the insertion point for |
| /// inserting xs[hi] to the sorted values xs[lo..hi). |
| // |
| // The generate IR corresponds to this C like algorithm: |
| // p = hi |
| // while (lo < hi) |
| // mid = (lo + hi) >> 1 |
| // if (xs[p] < xs[mid]) |
| // hi = mid |
| // else |
| // lo = mid - 1 |
| // return lo; |
| // |
| static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, |
| func::FuncOp func, AffineMap xPerm, |
| uint64_t ny, uint32_t nTrailingP = 0) { |
| // Binary search doesn't use trailing parameters. |
| (void)nTrailingP; |
| assert(nTrailingP == 0); |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| Block *entryBlock = func.addEntryBlock(); |
| builder.setInsertionPointToStart(entryBlock); |
| |
| Location loc = func.getLoc(); |
| ValueRange args = entryBlock->getArguments(); |
| Value p = args[hiIdx]; |
| SmallVector<Type, 2> types(2, p.getType()); // Only two types. |
| scf::WhileOp whileOp = builder.create<scf::WhileOp>( |
| loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]}); |
| |
| // The before-region of the WhileOp. |
| Block *before = |
| builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); |
| builder.setInsertionPointToEnd(before); |
| Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
| before->getArgument(0), |
| before->getArgument(1)); |
| builder.create<scf::ConditionOp>(loc, cond1, before->getArguments()); |
| |
| // The after-region of the WhileOp. |
| Block *after = |
| builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); |
| builder.setInsertionPointToEnd(after); |
| Value lo = after->getArgument(0); |
| Value hi = after->getArgument(1); |
| // Compute mid = (lo + hi) >> 1. |
| Value c1 = constantIndex(builder, loc, 1); |
| Value mid = builder.create<arith::ShRUIOp>( |
| loc, builder.create<arith::AddIOp>(loc, lo, hi), c1); |
| Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1); |
| |
| // Compare xs[p] < xs[mid]. |
| SmallVector<Value> compareOperands{p, mid}; |
| constexpr uint64_t numXBuffers = 1; |
| compareOperands.append(args.begin() + xStartIdx, |
| args.begin() + xStartIdx + numXBuffers); |
| Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); |
| // Update lo and hi for the WhileOp as follows: |
| // if (xs[p] < xs[mid])) |
| // hi = mid; |
| // else |
| // lo = mid + 1; |
| Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1); |
| Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi); |
| builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi}); |
| |
| builder.setInsertionPointAfter(whileOp); |
| builder.create<func::ReturnOp>(loc, whileOp.getResult(0)); |
| } |
| |
| /// Creates code to advance i in a loop based on xs[p] as follows: |
| /// while (xs[i] < xs[p]) i += step (step > 0) |
| /// or |
| /// while (xs[i] > xs[p]) i += step (step < 0) |
| /// The routine returns i as well as a boolean value to indicate whether |
| /// xs[i] == xs[p]. |
| static std::pair<Value, Value> createScanLoop(OpBuilder &builder, |
| ModuleOp module, |
| func::FuncOp func, ValueRange xs, |
| Value i, Value p, AffineMap xPerm, |
| uint64_t ny, int step) { |
| Location loc = func.getLoc(); |
| scf::WhileOp whileOp = |
| builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i}); |
| |
| Block *before = |
| builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); |
| builder.setInsertionPointToEnd(before); |
| SmallVector<Value> compareOperands; |
| if (step > 0) { |
| compareOperands.push_back(before->getArgument(0)); |
| compareOperands.push_back(p); |
| } else { |
| assert(step < 0); |
| compareOperands.push_back(p); |
| compareOperands.push_back(before->getArgument(0)); |
| } |
| compareOperands.append(xs.begin(), xs.end()); |
| Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); |
| builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); |
| |
| Block *after = |
| builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); |
| builder.setInsertionPointToEnd(after); |
| Value cs = constantIndex(builder, loc, step); |
| i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs); |
| builder.create<scf::YieldOp>(loc, ValueRange{i}); |
| i = whileOp.getResult(0); |
| |
| builder.setInsertionPointAfter(whileOp); |
| compareOperands[0] = i; |
| compareOperands[1] = p; |
| Value compareEq = |
| createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny); |
| |
| return std::make_pair(whileOp.getResult(0), compareEq); |
| } |
| |
| /// Creates and returns an IfOp to compare two elements and swap the elements |
| /// if compareFunc(data[b], data[a]) returns true. The new insertion point is |
| /// right after the swap instructions. |
| static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, |
| AffineMap xPerm, uint64_t ny, |
| SmallVectorImpl<Value> &swapOperands, |
| SmallVectorImpl<Value> &compareOperands, |
| Value a, Value b) { |
| // Compare(data[b], data[a]). |
| compareOperands[0] = b; |
| compareOperands[1] = a; |
| Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| swapOperands[0] = b; |
| swapOperands[1] = a; |
| createSwap(builder, loc, swapOperands, xPerm, ny); |
| return ifOp; |
| } |
| |
| /// Creates code to insert the 3rd element to a list of two sorted elements. |
| static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, |
| uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
| SmallVectorImpl<Value> &compareOperands, Value v0, |
| Value v1, Value v2) { |
| scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
| compareOperands, v1, v2); |
| createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, |
| v0, v1); |
| builder.setInsertionPointAfter(ifOp); |
| } |
| |
| /// Creates code to sort 3 elements. |
| static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, |
| uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
| SmallVectorImpl<Value> &compareOperands, Value v0, |
| Value v1, Value v2) { |
| // Sort the first 2 elements. |
| scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
| compareOperands, v0, v1); |
| builder.setInsertionPointAfter(ifOp1); |
| |
| // Insert the 3th element. |
| createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, |
| v1, v2); |
| } |
| |
| /// Creates code to sort 5 elements. |
| static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, |
| uint64_t ny, SmallVectorImpl<Value> &swapOperands, |
| SmallVectorImpl<Value> &compareOperands, Value v0, |
| Value v1, Value v2, Value v3, Value v4) { |
| // Sort the first 3 elements. |
| createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, |
| v2); |
| |
| auto insert4th = [&]() { |
| scf::IfOp ifOp = createCompareThenSwap( |
| builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3); |
| createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, |
| v1, v2); |
| builder.setInsertionPointAfter(ifOp); |
| }; |
| |
| // Insert the 4th element. |
| insert4th(); |
| |
| // Insert the 5th element. |
| scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, |
| compareOperands, v3, v4); |
| insert4th(); |
| builder.setInsertionPointAfter(ifOp); |
| } |
| |
| /// Creates a code block to swap the values in indices lo, mi, and hi so that |
| /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When |
| /// the number of values in range [lo, hi) is more than a threshold, we also |
| /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. |
| static void createChoosePivot(OpBuilder &builder, ModuleOp module, |
| func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| Value lo, Value hi, Value mi, ValueRange args) { |
| SmallVector<Value> compareOperands{mi, lo}; |
| constexpr uint64_t numXBuffers = 1; |
| compareOperands.append(args.begin() + xStartIdx, |
| args.begin() + xStartIdx + numXBuffers); |
| SmallVector<Value> swapOperands{mi, lo}; |
| swapOperands.append(args.begin() + xStartIdx, args.end()); |
| Location loc = func.getLoc(); |
| Value c1 = constantIndex(builder, loc, 1); |
| Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1); |
| Value len = builder.create<arith::SubIOp>(loc, hiP1, lo); |
| Value lenThreshold = constantIndex(builder, loc, 1000); |
| Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
| len, lenThreshold); |
| scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true); |
| |
| // When len < 1000, choose pivot from median of 3 values. |
| builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); |
| createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi, |
| hi); |
| |
| // When len >= 1000, choose pivot from median of 5 values. |
| builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); |
| Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1); |
| Value a = builder.create<arith::AddIOp>(loc, lo, miP1); |
| // Value a is the middle between [loc, mi]. |
| a = builder.create<arith::ShRUIOp>(loc, a, c1); |
| Value b = builder.create<arith::AddIOp>(loc, mi, hiP1); |
| // Value b is the middle between [mi, hi]. |
| b = builder.create<arith::ShRUIOp>(loc, b, c1); |
| createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi, |
| b, hi); |
| |
| builder.setInsertionPointAfter(lenIf); |
| } |
| |
| /// Creates a function to perform quick sort partition on the values in the |
| /// range of index [lo, hi), assuming lo < hi. |
| // |
| // The generated IR corresponds to this C like algorithm: |
| // int partition(lo, hi, xs) { |
| // p = (lo+hi)/2 // pivot index |
| // i = lo |
| // j = hi-1 |
| // while (true) do { |
| // while (xs[i] < xs[p]) i ++; |
| // i_eq = (xs[i] == xs[p]); |
| // while (xs[j] > xs[p]) j --; |
| // j_eq = (xs[j] == xs[p]); |
| // |
| // if (i >= j) return j + 1; |
| // |
| // if (i < j) { |
| // swap(xs[i], xs[j]) |
| // if (i == p) { |
| // p = j; |
| // } else if (j == p) { |
| // p = i; |
| // } |
| // if (i_eq && j_eq) { |
| // ++i; |
| // --j; |
| // } |
| // } |
| // } |
| // } |
| static void createPartitionFunc(OpBuilder &builder, ModuleOp module, |
| func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| uint32_t nTrailingP = 0) { |
| // Quick sort partition doesn't use trailing parameters. |
| (void)nTrailingP; |
| assert(nTrailingP == 0); |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| |
| Block *entryBlock = func.addEntryBlock(); |
| builder.setInsertionPointToStart(entryBlock); |
| |
| Location loc = func.getLoc(); |
| ValueRange args = entryBlock->getArguments(); |
| Value lo = args[loIdx]; |
| Value hi = args[hiIdx]; |
| Value sum = builder.create<arith::AddIOp>(loc, lo, hi); |
| Value c1 = constantIndex(builder, loc, 1); |
| Value p = builder.create<arith::ShRUIOp>(loc, sum, c1); |
| |
| Value i = lo; |
| Value j = builder.create<arith::SubIOp>(loc, hi, c1); |
| createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args); |
| Value trueVal = constantI1(builder, loc, true); // The value for while (true) |
| SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values. |
| SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(), |
| trueVal.getType()}; |
| scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); |
| |
| // The before-region of the WhileOp. |
| Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, |
| {loc, loc, loc, loc}); |
| builder.setInsertionPointToEnd(before); |
| builder.create<scf::ConditionOp>(loc, before->getArgument(3), |
| before->getArguments()); |
| |
| // The after-region of the WhileOp. |
| Block *after = |
| builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc}); |
| builder.setInsertionPointToEnd(after); |
| i = after->getArgument(0); |
| j = after->getArgument(1); |
| p = after->getArgument(2); |
| |
| constexpr uint64_t numXBuffers = 1; |
| auto [iresult, iCompareEq] = |
| createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), |
| i, p, xPerm, ny, 1); |
| i = iresult; |
| auto [jresult, jCompareEq] = |
| createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), |
| j, p, xPerm, ny, -1); |
| j = jresult; |
| |
| // If i < j: |
| Value cond = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| SmallVector<Value> swapOperands{i, j}; |
| swapOperands.append(args.begin() + xStartIdx, args.end()); |
| createSwap(builder, loc, swapOperands, xPerm, ny); |
| // If the pivot is moved, update p with the new pivot. |
| Value icond = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p); |
| scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, |
| icond, /*else=*/true); |
| builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{j}); |
| builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); |
| Value jcond = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p); |
| scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, |
| jcond, /*else=*/true); |
| builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{i}); |
| builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{p}); |
| builder.setInsertionPointAfter(ifOpJ); |
| builder.create<scf::YieldOp>(loc, ifOpJ.getResults()); |
| builder.setInsertionPointAfter(ifOpI); |
| Value compareEqIJ = |
| builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq); |
| scf::IfOp ifOp2 = builder.create<scf::IfOp>( |
| loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); |
| builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); |
| Value i2 = builder.create<arith::AddIOp>(loc, i, c1); |
| Value j2 = builder.create<arith::SubIOp>(loc, j, c1); |
| builder.create<scf::YieldOp>(loc, ValueRange{i2, j2}); |
| builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{i, j}); |
| builder.setInsertionPointAfter(ifOp2); |
| builder.create<scf::YieldOp>( |
| loc, |
| ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0), |
| /*cont=*/constantI1(builder, loc, true)}); |
| |
| // False branch for if i < j (i.e., i >= j): |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| p = builder.create<arith::AddIOp>(loc, j, |
| constantOne(builder, loc, j.getType())); |
| builder.create<scf::YieldOp>( |
| loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)}); |
| |
| // Return for the whileOp. |
| builder.setInsertionPointAfter(ifOp); |
| builder.create<scf::YieldOp>(loc, ifOp.getResults()); |
| |
| // Return for the function. |
| builder.setInsertionPointAfter(whileOp); |
| builder.create<func::ReturnOp>(loc, whileOp.getResult(2)); |
| } |
| |
| /// Computes (n-2)/n, assuming n has index type. |
| static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, |
| Value n) { |
| Value i2 = constantIndex(builder, loc, 2); |
| Value res = builder.create<arith::SubIOp>(loc, n, i2); |
| Value i1 = constantIndex(builder, loc, 1); |
| return builder.create<arith::ShRUIOp>(loc, res, i1); |
| } |
| |
| /// Creates a function to heapify the subtree with root `start` within the full |
| /// binary tree in the range of index [first, first + n). |
| // |
| // The generated IR corresponds to this C like algorithm: |
| // void shiftDown(first, start, n, data) { |
| // if (n >= 2) { |
| // child = start - first |
| // if ((n-2)/2 >= child) { |
| // // Left child exists. |
| // child = child * 2 + 1 // Initialize the bigger child to left child. |
| // childIndex = child + first |
| // if (child+1 < n && data[childIndex] < data[childIndex+1]) |
| // // Right child exits and is bigger. |
| // childIndex++; child++; |
| // // Shift data[start] down to where it belongs in the subtree. |
| // while (data[start] < data[childIndex) { |
| // swap(data[start], data[childIndex]) |
| // start = childIndex |
| // if ((n - 2)/2 >= child) { |
| // // Left child exists. |
| // child = 2*child + 1 |
| // childIndex = child + 1 |
| // if (child + 1) < n && data[childIndex] < data[childIndex+1] |
| // childIndex++; child++; |
| // } |
| // } |
| // } |
| // } |
| // } |
| // |
| static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, |
| func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| uint32_t nTrailingP) { |
| // The value n is passed in as a trailing parameter. |
| assert(nTrailingP == 1); |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| Block *entryBlock = func.addEntryBlock(); |
| builder.setInsertionPointToStart(entryBlock); |
| |
| Location loc = func.getLoc(); |
| Value n = entryBlock->getArguments().back(); |
| ValueRange args = entryBlock->getArguments().drop_back(); |
| Value first = args[loIdx]; |
| Value start = args[hiIdx]; |
| |
| // If (n >= 2). |
| Value c2 = constantIndex(builder, loc, 2); |
| Value condN = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2); |
| scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false); |
| builder.setInsertionPointToStart(&ifN.getThenRegion().front()); |
| Value child = builder.create<arith::SubIOp>(loc, start, first); |
| |
| // If ((n-2)/2 >= child). |
| Value t = createSubTwoDividedByTwo(builder, loc, n); |
| Value condNc = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); |
| scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false); |
| |
| builder.setInsertionPointToStart(&ifNc.getThenRegion().front()); |
| Value c1 = constantIndex(builder, loc, 1); |
| SmallVector<Value> compareOperands{start, start}; |
| constexpr uint64_t numXBuffers = 1; |
| compareOperands.append(args.begin() + xStartIdx, |
| args.begin() + xStartIdx + numXBuffers); |
| |
| // Generate code to inspect the children of 'r' and return the larger child |
| // as follows: |
| // child = r * 2 + 1 // Left child. |
| // childIndex = child + first |
| // if (child+1 < n && data[childIndex] < data[childIndex+1]) |
| // childIndex ++; child ++ // Right child is bigger. |
| auto getLargerChild = [&](Value r) -> std::pair<Value, Value> { |
| Value lChild = builder.create<arith::ShLIOp>(loc, r, c1); |
| lChild = builder.create<arith::AddIOp>(loc, lChild, c1); |
| Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first); |
| Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1); |
| Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
| rChild, n); |
| SmallVector<Type, 2> ifTypes(2, r.getType()); |
| scf::IfOp if1 = |
| builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true); |
| builder.setInsertionPointToStart(&if1.getThenRegion().front()); |
| Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first); |
| // Compare data[left] < data[right]. |
| compareOperands[0] = lChildIdx; |
| compareOperands[1] = rChildIdx; |
| Value cond2 = |
| createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); |
| scf::IfOp if2 = |
| builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true); |
| builder.setInsertionPointToStart(&if2.getThenRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx}); |
| builder.setInsertionPointToStart(&if2.getElseRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); |
| builder.setInsertionPointAfter(if2); |
| builder.create<scf::YieldOp>(loc, if2.getResults()); |
| builder.setInsertionPointToStart(&if1.getElseRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); |
| builder.setInsertionPointAfter(if1); |
| return std::make_pair(if1.getResult(0), if1.getResult(1)); |
| }; |
| |
| Value childIdx; |
| std::tie(child, childIdx) = getLargerChild(child); |
| |
| // While (data[start] < data[childIndex]). |
| SmallVector<Type, 3> types(3, child.getType()); |
| scf::WhileOp whileOp = builder.create<scf::WhileOp>( |
| loc, types, SmallVector<Value, 2>{start, child, childIdx}); |
| |
| // The before-region of the WhileOp. |
| SmallVector<Location, 3> locs(3, loc); |
| Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); |
| builder.setInsertionPointToEnd(before); |
| start = before->getArgument(0); |
| childIdx = before->getArgument(2); |
| compareOperands[0] = start; |
| compareOperands[1] = childIdx; |
| Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); |
| builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); |
| |
| // The after-region of the WhileOp. |
| Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); |
| start = after->getArgument(0); |
| child = after->getArgument(1); |
| childIdx = after->getArgument(2); |
| SmallVector<Value> swapOperands{start, childIdx}; |
| swapOperands.append(args.begin() + xStartIdx, args.end()); |
| createSwap(builder, loc, swapOperands, xPerm, ny); |
| start = childIdx; |
| Value cond2 = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); |
| scf::IfOp if2 = builder.create<scf::IfOp>( |
| loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true); |
| builder.setInsertionPointToStart(&if2.getThenRegion().front()); |
| auto [newChild, newChildIdx] = getLargerChild(child); |
| builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx}); |
| builder.setInsertionPointToStart(&if2.getElseRegion().front()); |
| builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx}); |
| builder.setInsertionPointAfter(if2); |
| builder.create<scf::YieldOp>( |
| loc, ValueRange{start, if2.getResult(0), if2.getResult(1)}); |
| |
| builder.setInsertionPointAfter(ifN); |
| builder.create<func::ReturnOp>(loc); |
| } |
| |
| /// Creates a function to perform heap sort on the values in the range of index |
| /// [lo, hi) with the assumption hi - lo >= 2. |
| // |
| // The generate IR corresponds to this C like algorithm: |
| // void heapSort(lo, hi, data) { |
| // n = hi - lo |
| // for i = (n-2)/2 downto 0 |
| // shiftDown(lo, lo+i, n) |
| // |
| // for l = n downto 2 |
| // swap(lo, lo+l-1) |
| // shiftdown(lo, lo, l-1) |
| // } |
| static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, |
| func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| uint32_t nTrailingP) { |
| // Heap sort function doesn't have trailing parameters. |
| (void)nTrailingP; |
| assert(nTrailingP == 0); |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| Block *entryBlock = func.addEntryBlock(); |
| builder.setInsertionPointToStart(entryBlock); |
| |
| Location loc = func.getLoc(); |
| ValueRange args = entryBlock->getArguments(); |
| Value lo = args[loIdx]; |
| Value hi = args[hiIdx]; |
| Value n = builder.create<arith::SubIOp>(loc, hi, lo); |
| |
| // For i = (n-2)/2 downto 0. |
| Value c0 = constantIndex(builder, loc, 0); |
| Value c1 = constantIndex(builder, loc, 1); |
| Value s = createSubTwoDividedByTwo(builder, loc, n); |
| Value up = builder.create<arith::AddIOp>(loc, s, c1); |
| scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1); |
| builder.setInsertionPointToStart(forI.getBody()); |
| Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar()); |
| Value lopi = builder.create<arith::AddIOp>(loc, lo, i); |
| SmallVector<Value> shiftDownOperands = {lo, lopi}; |
| shiftDownOperands.append(args.begin() + xStartIdx, args.end()); |
| shiftDownOperands.push_back(n); |
| FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( |
| builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, |
| shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); |
| builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), |
| shiftDownOperands); |
| |
| builder.setInsertionPointAfter(forI); |
| // For l = n downto 2. |
| up = builder.create<arith::SubIOp>(loc, n, c1); |
| scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1); |
| builder.setInsertionPointToStart(forL.getBody()); |
| Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar()); |
| Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l); |
| loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1); |
| SmallVector<Value> swapOperands{lo, loplm1}; |
| swapOperands.append(args.begin() + xStartIdx, args.end()); |
| createSwap(builder, loc, swapOperands, xPerm, ny); |
| shiftDownOperands[1] = lo; |
| shiftDownOperands[shiftDownOperands.size() - 1] = |
| builder.create<arith::SubIOp>(loc, l, c1); |
| builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), |
| shiftDownOperands); |
| |
| builder.setInsertionPointAfter(forL); |
| builder.create<func::ReturnOp>(loc); |
| } |
| |
| /// A helper for generating code to perform quick sort. It partitions [lo, hi), |
| /// recursively calls quick sort to process the smaller partition and returns |
| /// the bigger partition to be processed by the enclosed while-loop. |
| static std::pair<Value, Value> |
| createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, |
| ValueRange args, AffineMap xPerm, uint64_t ny, |
| uint32_t nTrailingP) { |
| MLIRContext *context = module.getContext(); |
| Location loc = func.getLoc(); |
| Value lo = args[loIdx]; |
| Value hi = args[hiIdx]; |
| SmallVector<Type, 2> types(2, lo.getType()); // Only two types. |
| |
| FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( |
| builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, |
| ny, args.drop_back(nTrailingP), createPartitionFunc); |
| Value p = builder |
| .create<func::CallOp>(loc, partitionFunc, |
| TypeRange{IndexType::get(context)}, |
| args.drop_back(nTrailingP)) |
| .getResult(0); |
| |
| Value lenLow = builder.create<arith::SubIOp>(loc, p, lo); |
| Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p); |
| // Partition already sorts array with len <= 2 |
| Value c2 = constantIndex(builder, loc, 2); |
| Value len = builder.create<arith::SubIOp>(loc, hi, lo); |
| Value lenGtTwo = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2); |
| scf::IfOp ifLenGtTwo = |
| builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true); |
| builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); |
| // Returns an empty range to mark the entire region is fully sorted. |
| builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
| |
| // Else len > 2, need recursion. |
| builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); |
| Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, |
| lenLow, lenHigh); |
| |
| Value c0 = constantIndex(builder, loc, 0); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); |
| |
| auto mayRecursion = [&](Value low, Value high, Value len) { |
| Value cond = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0); |
| scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| SmallVector<Value> operands{low, high}; |
| operands.append(args.begin() + xStartIdx, args.end()); |
| builder.create<func::CallOp>(loc, func, operands); |
| builder.setInsertionPointAfter(ifOp); |
| }; |
| |
| // Recursively call quickSort to process the smaller partition and return |
| // the bigger partition to be processed by the enclosed while-loop. |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| mayRecursion(lo, p, lenLow); |
| builder.create<scf::YieldOp>(loc, ValueRange{p, hi}); |
| |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| mayRecursion(p, hi, lenHigh); |
| builder.create<scf::YieldOp>(loc, ValueRange{lo, p}); |
| |
| builder.setInsertionPointAfter(ifOp); |
| builder.create<scf::YieldOp>(loc, ifOp.getResults()); |
| |
| builder.setInsertionPointAfter(ifLenGtTwo); |
| return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); |
| } |
| |
| /// Creates a function to perform insertion sort on the values in the range of |
| /// index [lo, hi). |
| // |
| // The generate IR corresponds to this C like algorithm: |
| // void insertionSort(lo, hi, data) { |
| // for (i = lo+1; i < hi; i++) { |
| // d = data[i]; |
| // p = binarySearch(lo, i-1, data) |
| // for (j = 0; j > i - p; j++) |
| // data[i-j] = data[i-j-1] |
| // data[p] = d |
| // } |
| // } |
| static void createSortStableFunc(OpBuilder &builder, ModuleOp module, |
| func::FuncOp func, AffineMap xPerm, |
| uint64_t ny, uint32_t nTrailingP) { |
| // Stable sort function doesn't use trailing parameters. |
| (void)nTrailingP; |
| assert(nTrailingP == 0); |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| Block *entryBlock = func.addEntryBlock(); |
| builder.setInsertionPointToStart(entryBlock); |
| |
| MLIRContext *context = module.getContext(); |
| Location loc = func.getLoc(); |
| ValueRange args = entryBlock->getArguments(); |
| Value c1 = constantIndex(builder, loc, 1); |
| Value lo = args[loIdx]; |
| Value hi = args[hiIdx]; |
| Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); |
| |
| // Start the outer for-stmt with induction variable i. |
| scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); |
| builder.setInsertionPointToStart(forOpI.getBody()); |
| Value i = forOpI.getInductionVar(); |
| |
| // Binary search to find the insertion point p. |
| SmallVector<Value> operands{lo, i}; |
| operands.append(args.begin() + xStartIdx, args.end()); |
| FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( |
| builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, |
| xPerm, ny, operands, createBinarySearchFunc); |
| Value p = builder |
| .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, |
| operands) |
| .getResult(0); |
| |
| // Move the value at data[i] to a temporary location. |
| operands[0] = operands[1] = i; |
| SmallVector<Value> d; |
| forEachIJPairInAllBuffers( |
| builder, loc, operands, xPerm, ny, |
| [&](uint64_t unused, Value i, Value unused2, Value buffer) { |
| d.push_back(builder.create<memref::LoadOp>(loc, buffer, i)); |
| }); |
| |
| // Start the inner for-stmt with induction variable j, for moving data[p..i) |
| // to data[p+1..i+1). |
| Value imp = builder.create<arith::SubIOp>(loc, i, p); |
| Value c0 = constantIndex(builder, loc, 0); |
| scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); |
| builder.setInsertionPointToStart(forOpJ.getBody()); |
| Value j = forOpJ.getInductionVar(); |
| Value imj = builder.create<arith::SubIOp>(loc, i, j); |
| operands[1] = imj; |
| operands[0] = builder.create<arith::SubIOp>(loc, imj, c1); |
| forEachIJPairInAllBuffers( |
| builder, loc, operands, xPerm, ny, |
| [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { |
| Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1); |
| builder.create<memref::StoreOp>(loc, t, buffer, imj); |
| }); |
| |
| // Store the value at data[i] to data[p]. |
| builder.setInsertionPointAfter(forOpJ); |
| operands[0] = operands[1] = p; |
| forEachIJPairInAllBuffers( |
| builder, loc, operands, xPerm, ny, |
| [&](uint64_t k, Value p, Value usused, Value buffer) { |
| builder.create<memref::StoreOp>(loc, d[k], buffer, p); |
| }); |
| |
| builder.setInsertionPointAfter(forOpI); |
| builder.create<func::ReturnOp>(loc); |
| } |
| |
| /// Creates a function to perform quick sort or a hybrid quick sort on the |
| /// values in the range of index [lo, hi). |
| // |
| // |
| // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: |
| // void quickSort(lo, hi, data) { |
| // while (lo + 1 < hi) { |
| // p = partition(low, high, data); |
| // if (len(lo, p) < len(p+1, hi)) { |
| // quickSort(lo, p, data); |
| // lo = p+1; |
| // } else { |
| // quickSort(p + 1, hi, data); |
| // hi = p; |
| // } |
| // } |
| // } |
| // |
| // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: |
| // void hybridQuickSort(lo, hi, data, depthLimit) { |
| // while (lo + 1 < hi) { |
| // len = hi - lo; |
| // if (len <= limit) { |
| // insertionSort(lo, hi, data); |
| // } else { |
| // depthLimit --; |
| // if (depthLimit <= 0) { |
| // heapSort(lo, hi, data); |
| // } else { |
| // p = partition(low, high, data); |
| // if (len(lo, p) < len(p+1, hi)) { |
| // quickSort(lo, p, data, depthLimit); |
| // lo = p+1; |
| // } else { |
| // quickSort(p + 1, hi, data, depthLimit); |
| // hi = p; |
| // } |
| // } |
| // } |
| // } |
| // } |
| // |
| static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, |
| func::FuncOp func, AffineMap xPerm, uint64_t ny, |
| uint32_t nTrailingP) { |
| assert(nTrailingP == 1 || nTrailingP == 0); |
| bool isHybrid = (nTrailingP == 1); |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| Block *entryBlock = func.addEntryBlock(); |
| builder.setInsertionPointToStart(entryBlock); |
| |
| Location loc = func.getLoc(); |
| SmallVector<Value> args; |
| args.append(entryBlock->getArguments().begin(), |
| entryBlock->getArguments().end()); |
| Value lo = args[loIdx]; |
| Value hi = args[hiIdx]; |
| SmallVector<Type, 2> types(2, lo.getType()); // Only two types. |
| scf::WhileOp whileOp = |
| builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi}); |
| |
| // The before-region of the WhileOp. |
| Block *before = |
| builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); |
| builder.setInsertionPointToEnd(before); |
| lo = before->getArgument(0); |
| hi = before->getArgument(1); |
| Value loP1 = |
| builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1)); |
| Value needSort = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi); |
| builder.create<scf::ConditionOp>(loc, needSort, before->getArguments()); |
| |
| // The after-region of the WhileOp. |
| Block *after = |
| builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); |
| builder.setInsertionPointToEnd(after); |
| lo = after->getArgument(0); |
| hi = after->getArgument(1); |
| args[0] = lo; |
| args[1] = hi; |
| |
| if (isHybrid) { |
| Value len = builder.create<arith::SubIOp>(loc, hi, lo); |
| Value lenLimit = constantIndex(builder, loc, 30); |
| Value lenCond = builder.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::ule, len, lenLimit); |
| scf::IfOp lenIf = |
| builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true); |
| |
| // When len <= limit. |
| builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); |
| FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( |
| builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, |
| ValueRange(args).drop_back(nTrailingP), createSortStableFunc); |
| builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(), |
| ValueRange(args).drop_back(nTrailingP)); |
| builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
| |
| // When len > limit. |
| builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); |
| Value depthLimit = args.back(); |
| depthLimit = builder.create<arith::SubIOp>(loc, depthLimit, |
| constantI64(builder, loc, 1)); |
| Value depthCond = |
| builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, |
| depthLimit, constantI64(builder, loc, 0)); |
| scf::IfOp depthIf = |
| builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true); |
| |
| // When depth exceeds limit. |
| builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); |
| FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( |
| builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, |
| ValueRange(args).drop_back(nTrailingP), createHeapSortFunc); |
| builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(), |
| ValueRange(args).drop_back(nTrailingP)); |
| builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); |
| |
| // When depth doesn't exceed limit. |
| builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); |
| args.back() = depthLimit; |
| std::tie(lo, hi) = |
| createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); |
| builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
| |
| builder.setInsertionPointAfter(depthIf); |
| lo = depthIf.getResult(0); |
| hi = depthIf.getResult(1); |
| builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
| |
| builder.setInsertionPointAfter(lenIf); |
| lo = lenIf.getResult(0); |
| hi = lenIf.getResult(1); |
| } else { |
| std::tie(lo, hi) = |
| createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); |
| } |
| |
| // New [lo, hi) for the next while-loop iteration. |
| builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); |
| |
| // After the while-loop. |
| builder.setInsertionPointAfter(whileOp); |
| builder.create<func::ReturnOp>(loc); |
| } |
| |
| /// Implements the rewriting for operator sort and sort_coo. |
| template <typename OpTy> |
| LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, |
| uint64_t ny, PatternRewriter &rewriter) { |
| Location loc = op.getLoc(); |
| SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()}; |
| |
| // Convert `values` to have dynamic shape and append them to `operands`. |
| for (Value v : xys) { |
| auto mtp = getMemRefType(v); |
| if (!mtp.isDynamicDim(0)) { |
| auto newMtp = |
| MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); |
| v = rewriter.create<memref::CastOp>(loc, newMtp, v); |
| } |
| operands.push_back(v); |
| } |
| |
| auto insertPoint = op->template getParentOfType<func::FuncOp>(); |
| if (!insertPoint) |
| return failure(); |
| |
| SmallString<32> funcName; |
| FuncGeneratorType funcGenerator; |
| uint32_t nTrailingP = 0; |
| switch (op.getAlgorithm()) { |
| case SparseTensorSortKind::HybridQuickSort: { |
| funcName = kHybridQuickSortFuncNamePrefix; |
| funcGenerator = createQuickSortFunc; |
| nTrailingP = 1; |
| // As a heuristics, set depthLimit = 2 * log2(n). |
| Value lo = operands[loIdx]; |
| Value hi = operands[hiIdx]; |
| Value len = rewriter.create<arith::IndexCastOp>( |
| loc, rewriter.getI64Type(), |
| rewriter.create<arith::SubIOp>(loc, hi, lo)); |
| Value depthLimit = rewriter.create<arith::SubIOp>( |
| loc, constantI64(rewriter, loc, 64), |
| rewriter.create<math::CountLeadingZerosOp>(loc, len)); |
| operands.push_back(depthLimit); |
| break; |
| } |
| case SparseTensorSortKind::QuickSort: |
| funcName = kQuickSortFuncNamePrefix; |
| funcGenerator = createQuickSortFunc; |
| break; |
| case SparseTensorSortKind::InsertionSortStable: |
| funcName = kSortStableFuncNamePrefix; |
| funcGenerator = createSortStableFunc; |
| break; |
| case SparseTensorSortKind::HeapSort: |
| funcName = kHeapSortFuncNamePrefix; |
| funcGenerator = createHeapSortFunc; |
| break; |
| } |
| |
| FlatSymbolRefAttr func = |
| getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, |
| xPerm, ny, operands, funcGenerator, nTrailingP); |
| rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); |
| return success(); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // The actual sparse buffer rewriting rules. |
| //===---------------------------------------------------------------------===// |
| |
| namespace { |
| /// Sparse rewriting rule for the push_back operator. |
| struct PushBackRewriter : OpRewritePattern<PushBackOp> { |
| public: |
| using OpRewritePattern<PushBackOp>::OpRewritePattern; |
| PushBackRewriter(MLIRContext *context, bool enableInit) |
| : OpRewritePattern(context), enableBufferInitialization(enableInit) {} |
| LogicalResult matchAndRewrite(PushBackOp op, |
| PatternRewriter &rewriter) const override { |
| // Rewrite push_back(buffer, value, n) to: |
| // new_size = size(buffer) + n |
| // if (new_size > capacity(buffer)) |
| // while new_size > new_capacity |
| // new_capacity = new_capacity*2 |
| // new_buffer = realloc(buffer, new_capacity) |
| // buffer = new_buffer |
| // subBuffer = subviewof(buffer) |
| // linalg.fill subBuffer value |
| // |
| // size(buffer) += n |
| // |
| // The capacity check is skipped when the attribute inbounds is presented. |
| Location loc = op->getLoc(); |
| Value c0 = constantIndex(rewriter, loc, 0); |
| Value buffer = op.getInBuffer(); |
| Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); |
| Value size = op.getCurSize(); |
| Value value = op.getValue(); |
| |
| Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); |
| Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); |
| auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp()); |
| bool nIsOne = (nValue && nValue.value() == 1); |
| |
| if (!op.getInbounds()) { |
| Value cond = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::ugt, newSize, capacity); |
| |
| Value c2 = constantIndex(rewriter, loc, 2); |
| auto bufferType = |
| MemRefType::get({ShapedType::kDynamic}, value.getType()); |
| scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, |
| /*else=*/true); |
| // True branch. |
| rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| if (nIsOne) { |
| capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); |
| } else { |
| // Use a do-while loop to calculate the new capacity as follows: |
| // do { new_capacity *= 2 } while (size > new_capacity) |
| scf::WhileOp whileOp = |
| rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); |
| |
| // The before-region of the WhileOp. |
| Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, |
| {capacity.getType()}, {loc}); |
| rewriter.setInsertionPointToEnd(before); |
| |
| capacity = |
| rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); |
| cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, |
| newSize, capacity); |
| rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); |
| // The after-region of the WhileOp. |
| Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, |
| {capacity.getType()}, {loc}); |
| rewriter.setInsertionPointToEnd(after); |
| rewriter.create<scf::YieldOp>(loc, after->getArguments()); |
| |
| rewriter.setInsertionPointAfter(whileOp); |
| capacity = whileOp.getResult(0); |
| } |
| |
| Value newBuffer = |
| rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); |
| if (enableBufferInitialization) { |
| Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); |
| Value fillValue = constantZero(rewriter, loc, value.getType()); |
| Value subBuffer = rewriter.create<memref::SubViewOp>( |
| loc, newBuffer, /*offset=*/ValueRange{newSize}, |
| /*size=*/ValueRange{fillSize}, |
| /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); |
| rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); |
| } |
| rewriter.create<scf::YieldOp>(loc, newBuffer); |
| |
| // False branch. |
| rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| rewriter.create<scf::YieldOp>(loc, buffer); |
| |
| // Prepare for adding the value to the end of the buffer. |
| rewriter.setInsertionPointAfter(ifOp); |
| buffer = ifOp.getResult(0); |
| } |
| |
| // Add the value to the end of the buffer. |
| if (nIsOne) { |
| rewriter.create<memref::StoreOp>(loc, value, buffer, size); |
| } else { |
| Value subBuffer = rewriter.create<memref::SubViewOp>( |
| loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, |
| /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); |
| rewriter.create<linalg::FillOp>(loc, value, subBuffer); |
| } |
| |
| // Update the buffer size. |
| rewriter.replaceOp(op, {buffer, newSize}); |
| return success(); |
| } |
| |
| private: |
| bool enableBufferInitialization; |
| }; |
| |
| /// Sparse rewriting rule for the sort_coo operator. |
| struct SortRewriter : public OpRewritePattern<SortOp> { |
| public: |
| using OpRewritePattern<SortOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(SortOp op, |
| PatternRewriter &rewriter) const override { |
| SmallVector<Value> xys; |
| xys.push_back(op.getXy()); |
| xys.append(op.getYs().begin(), op.getYs().end()); |
| |
| auto xPerm = op.getPermMap(); |
| uint64_t ny = 0; |
| if (auto nyAttr = op.getNyAttr()) |
| ny = nyAttr.getInt(); |
| |
| return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter); |
| } |
| }; |
| |
| } // namespace |
| |
| //===---------------------------------------------------------------------===// |
| // Methods that add patterns described in this file to a pattern list. |
| //===---------------------------------------------------------------------===// |
| |
| void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, |
| bool enableBufferInitialization) { |
| patterns.add<PushBackRewriter>(patterns.getContext(), |
| enableBufferInitialization); |
| patterns.add<SortRewriter>(patterns.getContext()); |
| } |