//===- SparseAssembler.cpp - adds wrapper methods around sparse types -----===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Utils/CodegenUtils.h" |
| |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "llvm/Support/FormatVariadic.h" |
| |
| using namespace mlir; |
| using namespace sparse_tensor; |
| |
| //===----------------------------------------------------------------------===// |
| // Helper methods. |
| //===----------------------------------------------------------------------===// |
| |
// Convert a type range to a new type range, with sparse tensors externalized.
| static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes, |
| SmallVectorImpl<Type> *extraTypes, bool directOut) { |
| for (auto type : types) { |
| // All "dense" data passes through unmodified. |
| if (!getSparseTensorEncoding(type)) { |
| convTypes.push_back(type); |
| continue; |
| } |
| |
    // Expand a sparse tensor type into the external types of its
    // pos/crd/val arrays.
| const SparseTensorType stt(cast<RankedTensorType>(type)); |
| foreachFieldAndTypeInSparseTensor( |
| stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex, |
| SparseTensorFieldKind kind, |
| Level, LevelType) { |
| if (kind == SparseTensorFieldKind::PosMemRef || |
| kind == SparseTensorFieldKind::CrdMemRef || |
| kind == SparseTensorFieldKind::ValMemRef) { |
| auto rtp = cast<ShapedType>(t); |
| if (!directOut) { |
| rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
| if (extraTypes) |
| extraTypes->push_back(rtp); |
| } |
| convTypes.push_back(rtp); |
| } |
| return true; |
| }); |
| } |
| } |
| |
| // Convert input and output values to [dis]assemble ops for sparse tensors. |
| static void convVals(OpBuilder &builder, Location loc, TypeRange types, |
| ValueRange fromVals, ValueRange extraVals, |
| SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn, |
| bool directOut) { |
| unsigned idx = 0; |
| for (auto type : types) { |
| // All "dense" data passes through unmodified. |
| if (!getSparseTensorEncoding(type)) { |
| toVals.push_back(fromVals[idx++]); |
| continue; |
| } |
| // Handle sparse data. |
| auto rtp = cast<RankedTensorType>(type); |
| const SparseTensorType stt(rtp); |
| SmallVector<Value> inputs; |
| SmallVector<Type> retTypes; |
| SmallVector<Type> cntTypes; |
| if (!isIn) |
| inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble |
| |
| // Collect the external representations of the pos/crd/val arrays. |
| foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex, |
| SparseTensorFieldKind kind, |
| Level lv, LevelType) { |
| if (kind == SparseTensorFieldKind::PosMemRef || |
| kind == SparseTensorFieldKind::CrdMemRef || |
| kind == SparseTensorFieldKind::ValMemRef) { |
| if (isIn) { |
| inputs.push_back(fromVals[idx++]); |
| } else if (directOut) { |
| Value mem; |
| if (kind == SparseTensorFieldKind::PosMemRef) |
| mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0], |
| lv); |
| else if (kind == SparseTensorFieldKind::CrdMemRef) |
| mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0], |
| lv); |
| else |
| mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]); |
| toVals.push_back(mem); |
| } else { |
| ShapedType rtp = cast<ShapedType>(t); |
| rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
| inputs.push_back(extraVals[extra++]); |
| retTypes.push_back(rtp); |
| cntTypes.push_back(builder.getIndexType()); |
| } |
| } |
| return true; |
| }); |
| |
| if (isIn) { |
| // Assemble multiple inputs into a single sparse tensor. |
| auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs); |
| toVals.push_back(a.getResult()); |
| } else if (!directOut) { |
      // Disassemble a single sparse input into multiple outputs. Note that
      // the op also returns the actual length counters, which are dropped.
| unsigned len = retTypes.size(); |
| retTypes.append(cntTypes); |
| auto d = |
| builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs); |
| for (unsigned i = 0; i < len; i++) |
| toVals.push_back(d.getResult(i)); |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Rewriting rules. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
| // |
| // In particular, each sparse tensor input |
| // |
| // void foo(..., t, ...) { } |
| // |
| // makes the original foo() internal and adds the following wrapper method |
| // |
| // void foo(..., t1..tn, ...) { |
| // t = assemble t1..tn |
| // _internal_foo(..., t, ...) |
| // } |
| // |
| // and likewise, each output tensor |
| // |
| // ... T ... bar(...) { return ..., t, ...; } |
| // |
| // makes the original bar() internal and adds the following wrapper method |
| // |
| // ... T1..TN ... bar(..., t1'..tn') { |
| // ..., t, ... = _internal_bar(...) |
| // t1..tn = disassemble t, t1'..tn' |
| // return ..., t1..tn, ... |
| // } |
| // |
// (with a direct-out variant that returns the underlying pos/crd/val arrays
// directly instead of disassembling into caller-provided buffers).
| // |
| struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| SparseFuncAssembler(MLIRContext *context, bool dO) |
| : OpRewritePattern(context), directOut(dO) {} |
| |
| LogicalResult matchAndRewrite(func::FuncOp funcOp, |
| PatternRewriter &rewriter) const override { |
| // Only rewrite public entry methods. |
| if (funcOp.isPrivate()) |
| return failure(); |
| |
| // Translate sparse tensor types to external types. |
| SmallVector<Type> inputTypes; |
| SmallVector<Type> outputTypes; |
| SmallVector<Type> extraTypes; |
| convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false); |
| convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut); |
| |
    // Only methods with sparse inputs or outputs need a wrapper method.
| if (inputTypes.size() == funcOp.getArgumentTypes().size() && |
| outputTypes.size() == funcOp.getResultTypes().size()) |
| return failure(); |
| |
| // Modify the original method into an internal, private method. |
| auto orgName = funcOp.getName(); |
| std::string wrapper = llvm::formatv("_internal_{0}", orgName).str(); |
| funcOp.setName(wrapper); |
| funcOp.setPrivate(); |
| |
    // Start the new public wrapper method with the original name.
| Location loc = funcOp.getLoc(); |
| ModuleOp modOp = funcOp->getParentOfType<ModuleOp>(); |
| MLIRContext *context = modOp.getContext(); |
| OpBuilder moduleBuilder(modOp.getBodyRegion()); |
| unsigned extra = inputTypes.size(); |
| inputTypes.append(extraTypes); |
| auto func = moduleBuilder.create<func::FuncOp>( |
| loc, orgName, FunctionType::get(context, inputTypes, outputTypes)); |
| func.setPublic(); |
| |
| // Construct new wrapper method body. |
| OpBuilder::InsertionGuard insertionGuard(rewriter); |
| Block *body = func.addEntryBlock(); |
| rewriter.setInsertionPointToStart(body); |
| |
| // Convert inputs. |
| SmallVector<Value> inputs; |
| convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(), |
| ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut); |
| |
| // Call the original, now private method. A subsequent inlining pass can |
| // determine whether cloning the method body in place is worthwhile. |
| auto org = SymbolRefAttr::get(context, wrapper); |
| auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org, |
| inputs); |
| |
| // Convert outputs and return. |
| SmallVector<Value> outputs; |
| convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(), |
| body->getArguments(), outputs, extra, /*isIn=*/false, directOut); |
| rewriter.create<func::ReturnOp>(loc, outputs); |
| |
    // Finally, migrate a potential C-interface property.
| if (funcOp->getAttrOfType<UnitAttr>( |
| LLVM::LLVMDialect::getEmitCWrapperAttrName())) { |
| func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), |
| UnitAttr::get(context)); |
| funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName()); |
| } |
| return success(); |
| } |
| |
| private: |
| const bool directOut; |
| }; |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Public method for populating conversion rules. |
| //===----------------------------------------------------------------------===// |
| |
| void mlir::populateSparseAssembler(RewritePatternSet &patterns, |
| bool directOut) { |
| patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut); |
| } |