[xla-next][sparse] sort the sparse custom call rewriters in alphabetical order
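
Reorders the Sparse*CallRewriter structs and the matching entries in
rewriter_map_ alphabetically by custom call target name, and hoists the
getEmptyTensor helper above its first use. No functional change intended.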
PiperOrigin-RevId: 537905129
GitOrigin-RevId: 5c17bd6a558922fce92a5bb3f9457bb8fb73165a
Change-Id: I96c50a84508b4c85c39af440672e868f43daaf6c
diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc
index 40eef9f..cedfb9f 100644
--- a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc
+++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc
@@ -70,6 +70,16 @@
values.append(range.begin(), range.end());
}
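+// Creates a zero-initialized tensor of the given ranked tensor type.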
+Value getEmptyTensor(OpBuilder& b, Location loc, RankedTensorType type) {
+ auto t = b.create<tensor::EmptyOp>(loc, type.getShape(),
+ type.getElementType(), ValueRange{});
+ auto zero = b.getZeroAttr(type.getElementType());
+ auto c0 = b.create<arith::ConstantOp>(loc, zero);
+ return b.create<linalg::FillOp>(loc, ValueRange{c0}, ValueRange{t})
+ .getResult(0);
+}
+
struct SparseBatchedPackCallRewriter {
LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
assert(op.getResults().size() == 1 && "Must be packing into one tensor");
@@ -81,118 +90,6 @@
}
};
-struct SparseUnpackCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getResults().size() + 1 == op.getInputs().size());
- // Both jax.BCSR and jax.BCOO has three memref fields.
- SmallVector<Type, 3> unpack_ret_tp(op.getResults().getTypes());
- Value tensor = op.getInputs()[0];
- Value out_vals = op.getInputs()[1];
- ValueRange out_lvls = op.getInputs().drop_front(2);
- // Constructs the UnpackOp.
- auto unpack_op = rewriter.create<sparse_tensor::UnpackOp>(
- op.getLoc(), unpack_ret_tp, tensor, out_vals, out_lvls);
- rewriter.replaceOp(op, unpack_op.getResults());
- return success();
- }
-};
-
-struct SparseTransposeCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 2 && "Need argument and permutation");
- assert(op.getResults().size() == 1 && "Need one output tensor");
-
- // The permutation is passed in as a constant of dense int elements.
- auto permutation_constant =
- op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
- auto permutation =
- permutation_constant.getValue().cast<DenseIntElementsAttr>();
-
- // Reconstruct the transpose operation.
- Value ret_sp_tensor = op.getResults()[0];
- rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
- op, ret_sp_tensor.getType(), op.getInputs()[0], permutation);
- return success();
- }
-};
-
-struct SparseDotCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 6 && "Need arguments and metadata");
- assert(op.getResults().size() == 1 && "Need one output tensor");
- SmallVector<int64_t> lhs_contr, rhs_contr, lhs_batch, rhs_batch;
- getIntegersFromDenseElements(op.getInputs()[2], lhs_contr);
- getIntegersFromDenseElements(op.getInputs()[3], rhs_contr);
- getIntegersFromDenseElements(op.getInputs()[4], lhs_batch);
- getIntegersFromDenseElements(op.getInputs()[5], rhs_batch);
- auto dot_dims = mlir::mhlo::DotDimensionNumbersAttr::get(
- op.getContext(), lhs_batch, rhs_batch, lhs_contr, rhs_contr);
- Value ret_sp_tensor = op.getResults()[0];
- rewriter.replaceOpWithNewOp<mhlo::DotGeneralOp>(
- op, ret_sp_tensor.getType(), op.getInputs()[0], op.getInputs()[1],
- dot_dims, /*defaultPrecision*/ ArrayAttr());
- return success();
- }
-};
-
-struct SparseConcatenateCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getResults().size() == 1 && "Need one output tensor");
-
- // The concatenation dimension.
- auto concat_dim = op.getInputs().back().getDefiningOp<mhlo::ConstantOp>();
- auto concat_dim_attr = concat_dim.getValue().cast<DenseIntElementsAttr>();
- // Reconstruct the concatenate operation.
- Value ret_sp_tensor = op.getResults()[0];
- // Depending on test setup, we can get either a 32-bit integer or a 64-bit
- // integer.
- if (concat_dim_attr.getElementType().isInteger(32)) {
- rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
- op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
- rewriter.getIndexAttr(concat_dim_attr.getValues<uint32_t>()[0]));
- } else {
- assert(concat_dim_attr.getElementType().isInteger(64));
- rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
- op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
- rewriter.getIndexAttr(concat_dim_attr.getValues<uint64_t>()[0]));
- }
-
- return success();
- }
-};
-
-struct SparseBroadcastInDimCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 2 &&
- "Need argument and broadcast dimensions");
- assert(op.getResults().size() == 1 && "Need one output tensor");
-
- // Broadcast dimensions are passed in as a constant of dense int elements.
- auto dims_constant = op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
- auto broadcast_dimensions =
- dims_constant.getValue().cast<DenseIntElementsAttr>();
-
- // Reconstruct the broadcast_in_dim operation.
- Value ret_sp_tensor = op.getResults()[0];
- rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
- op, ret_sp_tensor.getType(), op.getInputs()[0], broadcast_dimensions);
- return success();
- }
-};
-
-template <typename unaryChlo>
-struct SparseUnaryChloCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 1 && "Need one argument");
- assert(op.getResults().size() == 1 && "Need one output tensor");
- // Reconstruct the unary chlo operation.
- Value ret_sp_tensor = op.getResults()[0];
- rewriter.replaceOpWithNewOp<unaryChlo>(op, ret_sp_tensor.getType(),
- op.getInputs()[0]);
- return success();
- }
-};
-
template <typename BinaryMhlo>
struct SparseBinaryCallRewriter {
LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
@@ -206,52 +103,87 @@
}
};
-struct SparseSliceCallRewriter {
+struct SparseBroadcastInDimCallRewriter {
LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 4 &&
- "Need one operand and three slicing parameters");
+ assert(op.getInputs().size() == 2 &&
+ "Need argument and broadcast dimensions");
assert(op.getResults().size() == 1 && "Need one output tensor");
+ // Broadcast dimensions are passed in as a constant of dense int elements.
+ auto dims_constant = op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
+ auto broadcast_dimensions =
+ dims_constant.getValue().cast<DenseIntElementsAttr>();
+ // Reconstruct the broadcast_in_dim operation.
+ Value ret_sp_tensor = op.getResults()[0];
+ rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
+ op, ret_sp_tensor.getType(), op.getInputs()[0], broadcast_dimensions);
+ return success();
+ }
+};
- auto ctx = op.getContext();
- auto loc = op.getLoc();
- auto retTp = op.getResults().getTypes()[0].cast<RankedTensorType>();
-
- auto offsets = getDenseIntAttrFromConstant(op.getInputs()[1]);
- auto strides = getDenseIntAttrFromConstant(op.getInputs()[3]);
-
- assert(offsets.getNumElements() == strides.getNumElements() &&
- offsets.getNumElements() == retTp.getRank());
-
- SmallVector<sparse_tensor::SparseTensorDimSliceAttr> slice_attrs;
- SmallVector<int64_t> static_offsets, static_sizes, static_strides;
- for (auto [offset, size, stride] :
- llvm::zip(offsets, retTp.getShape(), strides)) {
- int64_t o = offset.getZExtValue(), s = stride.getZExtValue();
- // Converts limits to sizes.
- slice_attrs.push_back(
- sparse_tensor::SparseTensorDimSliceAttr::get(ctx, o, size, s));
- static_offsets.push_back(o);
- static_sizes.push_back(size);
- static_strides.push_back(s);
+struct SparseConcatenateCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ // The concatenation dimension.
+ auto concat_dim = op.getInputs().back().getDefiningOp<mhlo::ConstantOp>();
+ auto concat_dim_attr = concat_dim.getValue().cast<DenseIntElementsAttr>();
+ // Reconstruct the concatenate operation.
+ Value ret_sp_tensor = op.getResults()[0];
+ // Depending on test setup, we can get either a 32-bit integer or a 64-bit
+ // integer.
+ if (concat_dim_attr.getElementType().isInteger(32)) {
+ rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
+ op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
+ rewriter.getIndexAttr(concat_dim_attr.getValues<uint32_t>()[0]));
+ } else {
+ assert(concat_dim_attr.getElementType().isInteger(64));
+ rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
+ op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
+ rewriter.getIndexAttr(concat_dim_attr.getValues<uint64_t>()[0]));
}
+ return success();
+ }
+};
- auto srcEnc =
- retTp.getEncoding().cast<sparse_tensor::SparseTensorEncodingAttr>();
- // TODO(peiming): add a getSliceEncodingFrom into MLIR upstream.
- auto sliceEnc = sparse_tensor::SparseTensorEncodingAttr::get(
- ctx, srcEnc.getLvlTypes(), srcEnc.getDimToLvl(), srcEnc.getPosWidth(),
- srcEnc.getCrdWidth(), slice_attrs);
- auto sliceTp = RankedTensorType::get(retTp.getShape(),
- retTp.getElementType(), sliceEnc);
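+// Lowers conv_general_dilated to a linalg.conv_2d_nchw_fchw op on a
+// zero-filled output tensor.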
+struct SparseConvCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getInputs().size() == 2 && "Need two input tensors");
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ auto rtp = op.getResults()[0].getType().cast<RankedTensorType>();
+ rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(
+ op, op.getInputs(), getEmptyTensor(rewriter, op.getLoc(), rtp));
+ return success();
+ }
+};
- auto slice = rewriter.create<tensor::ExtractSliceOp>(
- loc, sliceTp, op.getInputs()[0], ValueRange(), ValueRange(),
- ValueRange(), static_offsets, static_sizes, static_strides);
+struct SparseConvertCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getInputs().size() == 1 && "Need one input tensor");
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ Value ret_sp_tensor = op.getResults()[0];
+ rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
+ op, ret_sp_tensor.getType(), op.getInputs()[0]);
+ return success();
+ }
+};
- // TODO(peiming): This weakens the performance benefit we get from the
- // sparse compiler by forcing every slice to be materizalized while the
- // sparse compiler supports view-based slice.
- rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(op, retTp, slice);
+struct SparseDotCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getInputs().size() == 6 && "Need arguments and metadata");
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ SmallVector<int64_t> lhs_contr, rhs_contr, lhs_batch, rhs_batch;
+ getIntegersFromDenseElements(op.getInputs()[2], lhs_contr);
+ getIntegersFromDenseElements(op.getInputs()[3], rhs_contr);
+ getIntegersFromDenseElements(op.getInputs()[4], lhs_batch);
+ getIntegersFromDenseElements(op.getInputs()[5], rhs_batch);
+ auto dot_dims = mlir::mhlo::DotDimensionNumbersAttr::get(
+ op.getContext(), lhs_batch, rhs_batch, lhs_contr, rhs_contr);
+ Value ret_sp_tensor = op.getResults()[0];
+ rewriter.replaceOpWithNewOp<mhlo::DotGeneralOp>(op, ret_sp_tensor.getType(),
+ op.getInputs()[0],
+ op.getInputs()[1], dot_dims,
+ /*defaultPrecision*/
+ ArrayAttr());
return success();
}
};
@@ -259,7 +190,6 @@
struct SparseDynSliceCallRewriter {
LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
assert(op.getResults().size() == 1 && "Need one output tensor");
-
auto ctx = op.getContext();
auto loc = op.getLoc();
auto retTp = op.getResults().getTypes()[0].cast<RankedTensorType>();
@@ -312,54 +242,6 @@
}
};
-struct SparseReshapeCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 1 && "Need one input tensor");
- assert(op.getResults().size() == 1 && "Need one output tensor");
-
- // Reconstruct the reshape operation.
- Value ret_sp_tensor = op.getResults()[0];
- // TODO(anlunx): Fix the issue that the reshape is rewritten to a collapse +
- // expand pair where the sparsity encoding is dropped in between.
- rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, ret_sp_tensor.getType(),
- op.getInputs()[0]);
- return success();
- }
-};
-
-static Value getEmptyTensor(OpBuilder& b, Location loc, RankedTensorType type) {
- auto t = b.create<tensor::EmptyOp>(loc, type.getShape(),
- type.getElementType(), ValueRange{});
- auto zero = b.getZeroAttr(type.getElementType());
- auto c0 = b.create<arith::ConstantOp>(loc, zero);
- return b.create<linalg::FillOp>(loc, ValueRange{c0}, ValueRange{t})
- .getResult(0);
-}
-
-struct SparseConvCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 2 && "Need two input tensors");
- assert(op.getResults().size() == 1 && "Need one output tensor");
-
- auto rtp = op.getResults()[0].getType().cast<RankedTensorType>();
-
- rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(
- op, op.getInputs(), getEmptyTensor(rewriter, op.getLoc(), rtp));
- return success();
- }
-};
-
-struct SparseConvertCallRewriter {
- LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
- assert(op.getInputs().size() == 1 && "Need one input tensor");
- assert(op.getResults().size() == 1 && "Need one output tensor");
- Value ret_sp_tensor = op.getResults()[0];
- rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
- op, ret_sp_tensor.getType(), op.getInputs()[0]);
- return success();
- }
-};
-
struct SparseReduceSumCallRewriter {
LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
assert(op.getInputs().size() == 2 && "Need one input tensor and axes");
@@ -394,27 +276,122 @@
rewriter.create<mhlo::AddOp>(loc, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(loc, addResult);
}
-
rewriter.replaceOp(op, reduce.getResults());
return success();
}
};
+struct SparseReshapeCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getInputs().size() == 1 && "Need one input tensor");
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ // Reconstruct the reshape operation.
+ Value ret_sp_tensor = op.getResults()[0];
+ // TODO(anlunx): Fix the issue that the reshape is rewritten to a collapse +
+ // expand pair where the sparsity encoding is dropped in between.
+ rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, ret_sp_tensor.getType(),
+ op.getInputs()[0]);
+ return success();
+ }
+};
+
+struct SparseSliceCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getInputs().size() == 4 &&
+ "Need one operand and three slicing parameters");
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ auto ctx = op.getContext();
+ auto loc = op.getLoc();
+ auto retTp = op.getResults().getTypes()[0].cast<RankedTensorType>();
+ auto offsets = getDenseIntAttrFromConstant(op.getInputs()[1]);
+ auto strides = getDenseIntAttrFromConstant(op.getInputs()[3]);
+ assert(offsets.getNumElements() == strides.getNumElements() &&
+ offsets.getNumElements() == retTp.getRank());
+ SmallVector<sparse_tensor::SparseTensorDimSliceAttr> slice_attrs;
+ SmallVector<int64_t> static_offsets, static_sizes, static_strides;
+ for (auto [offset, size, stride] :
+ llvm::zip(offsets, retTp.getShape(), strides)) {
+ int64_t o = offset.getZExtValue(), s = stride.getZExtValue();
+ // Converts limits to sizes.
+ slice_attrs.push_back(
+ sparse_tensor::SparseTensorDimSliceAttr::get(ctx, o, size, s));
+ static_offsets.push_back(o);
+ static_sizes.push_back(size);
+ static_strides.push_back(s);
+ }
+ auto srcEnc =
+ retTp.getEncoding().cast<sparse_tensor::SparseTensorEncodingAttr>();
+ // TODO(peiming): add a getSliceEncodingFrom into MLIR upstream.
+ auto sliceEnc = sparse_tensor::SparseTensorEncodingAttr::get(
+ ctx, srcEnc.getLvlTypes(), srcEnc.getDimToLvl(), srcEnc.getPosWidth(),
+ srcEnc.getCrdWidth(), slice_attrs);
+ auto sliceTp = RankedTensorType::get(retTp.getShape(),
+ retTp.getElementType(), sliceEnc);
+ auto slice = rewriter.create<tensor::ExtractSliceOp>(
+ loc, sliceTp, op.getInputs()[0], ValueRange(), ValueRange(),
+ ValueRange(), static_offsets, static_sizes, static_strides);
+ // TODO(peiming): This weakens the performance benefit we get from the
+ // sparse compiler by forcing every slice to be materialized while the
+ // sparse compiler supports view-based slice.
+ rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(op, retTp, slice);
+ return success();
+ }
+};
+
+struct SparseTransposeCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getInputs().size() == 2 && "Need argument and permutation");
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ // The permutation is passed in as a constant of dense int elements.
+ auto permutation_constant =
+ op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
+ auto permutation =
+ permutation_constant.getValue().cast<DenseIntElementsAttr>();
+ // Reconstruct the transpose operation.
+ Value ret_sp_tensor = op.getResults()[0];
+ rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
+ op, ret_sp_tensor.getType(), op.getInputs()[0], permutation);
+ return success();
+ }
+};
+
+template <typename UnaryChlo>
+struct SparseUnaryChloCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getInputs().size() == 1 && "Need one argument");
+ assert(op.getResults().size() == 1 && "Need one output tensor");
+ // Reconstruct the unary chlo operation.
+ Value ret_sp_tensor = op.getResults()[0];
+    rewriter.replaceOpWithNewOp<UnaryChlo>(op, ret_sp_tensor.getType(),
+ op.getInputs()[0]);
+ return success();
+ }
+};
+
+struct SparseUnpackCallRewriter {
+ LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+ assert(op.getResults().size() + 1 == op.getInputs().size());
+    // Both jax.BCSR and jax.BCOO have three memref fields.
+ SmallVector<Type, 3> unpack_ret_tp(op.getResults().getTypes());
+ Value tensor = op.getInputs()[0];
+ Value out_vals = op.getInputs()[1];
+ ValueRange out_lvls = op.getInputs().drop_front(2);
+ // Constructs the UnpackOp.
+ auto unpack_op = rewriter.create<sparse_tensor::UnpackOp>(
+ op.getLoc(), unpack_ret_tp, tensor, out_vals, out_lvls);
+ rewriter.replaceOp(op, unpack_op.getResults());
+ return success();
+ }
+};
+
class SparseCustomCallRewriter : public OpRewritePattern<mhlo::CustomCallOp> {
using OpRewritePattern<mhlo::CustomCallOp>::OpRewritePattern;
using SparseCustomTargetRewriter = std::function<LogicalResult(
mhlo::CustomCallOp op, PatternRewriter& rewriter)>;
const llvm::StringMap<SparseCustomTargetRewriter> rewriter_map_{
- std::make_pair("sparse_tensor_sparse_pack",
- SparseBatchedPackCallRewriter()),
- std::make_pair("sparse_tensor_sparse_unpack", SparseUnpackCallRewriter()),
- std::make_pair("sparse_tensor_transpose", SparseTransposeCallRewriter()),
- std::make_pair("sparse_tensor_dot_general", SparseDotCallRewriter()),
- std::make_pair("sparse_tensor_concatenate",
- SparseConcatenateCallRewriter()),
- std::make_pair("sparse_tensor_broadcast_in_dim",
- SparseBroadcastInDimCallRewriter()),
+ std::make_pair("sparse_tensor_add",
+ SparseBinaryCallRewriter<mhlo::AddOp>()),
std::make_pair("sparse_tensor_asin",
SparseUnaryChloCallRewriter<chlo::AsinOp>()),
std::make_pair("sparse_tensor_asinh",
@@ -425,24 +402,31 @@
SparseUnaryChloCallRewriter<chlo::AtanhOp>()),
std::make_pair("sparse_tensor_bessel_i1e",
SparseUnaryChloCallRewriter<chlo::BesselI1eOp>()),
- std::make_pair("sparse_tensor_sinh",
- SparseUnaryChloCallRewriter<chlo::SinhOp>()),
- std::make_pair("sparse_tensor_tan",
- SparseUnaryChloCallRewriter<chlo::TanOp>()),
- std::make_pair("sparse_tensor_slice", SparseSliceCallRewriter()),
- std::make_pair("sparse_tensor_dynamic_slice",
- SparseDynSliceCallRewriter()),
- std::make_pair("sparse_tensor_reshape", SparseReshapeCallRewriter()),
- std::make_pair("sparse_tensor_reduce_sum", SparseReduceSumCallRewriter()),
+ std::make_pair("sparse_tensor_broadcast_in_dim",
+ SparseBroadcastInDimCallRewriter()),
+ std::make_pair("sparse_tensor_concatenate",
+ SparseConcatenateCallRewriter()),
std::make_pair("sparse_tensor_conv_general_dilated",
SparseConvCallRewriter()),
std::make_pair("sparse_tensor_convert", SparseConvertCallRewriter()),
- std::make_pair("sparse_tensor_add",
- SparseBinaryCallRewriter<mhlo::AddOp>()),
- std::make_pair("sparse_tensor_sub",
- SparseBinaryCallRewriter<mhlo::SubtractOp>()),
+ std::make_pair("sparse_tensor_dot_general", SparseDotCallRewriter()),
+ std::make_pair("sparse_tensor_dynamic_slice",
+ SparseDynSliceCallRewriter()),
std::make_pair("sparse_tensor_mul",
SparseBinaryCallRewriter<mhlo::MulOp>()),
+ std::make_pair("sparse_tensor_reduce_sum", SparseReduceSumCallRewriter()),
+ std::make_pair("sparse_tensor_reshape", SparseReshapeCallRewriter()),
+ std::make_pair("sparse_tensor_sinh",
+ SparseUnaryChloCallRewriter<chlo::SinhOp>()),
+ std::make_pair("sparse_tensor_slice", SparseSliceCallRewriter()),
+ std::make_pair("sparse_tensor_sparse_pack",
+ SparseBatchedPackCallRewriter()),
+ std::make_pair("sparse_tensor_sparse_unpack", SparseUnpackCallRewriter()),
+ std::make_pair("sparse_tensor_sub",
+ SparseBinaryCallRewriter<mhlo::SubtractOp>()),
+ std::make_pair("sparse_tensor_tan",
+ SparseUnaryChloCallRewriter<chlo::TanOp>()),
+ std::make_pair("sparse_tensor_transpose", SparseTransposeCallRewriter()),
};
// Rewrites a CustomCallOp to corresponding sparse_tensor operation.