[xla-next][sparse] sort the sparse custom call rewriters in alphabetical order
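
This only reorders code: the Sparse*CallRewriter structs and the entries in
SparseCustomCallRewriter::rewriter_map_ are now listed alphabetically by
custom call target name, and the getEmptyTensor helper is hoisted above its
first user in the new order (SparseConvCallRewriter). No functional change
is intended.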

PiperOrigin-RevId: 537905129
GitOrigin-RevId: 5c17bd6a558922fce92a5bb3f9457bb8fb73165a
Change-Id: I96c50a84508b4c85c39af440672e868f43daaf6c
diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc
index 40eef9f..cedfb9f 100644
--- a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc
+++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/sparse_rewrite_passes.cc
@@ -70,6 +70,15 @@
   values.append(range.begin(), range.end());
 }
 
+Value getEmptyTensor(OpBuilder& b, Location loc, RankedTensorType type) {
+  auto t = b.create<tensor::EmptyOp>(loc, type.getShape(),
+                                     type.getElementType(), ValueRange{});
+  auto zero = b.getZeroAttr(type.getElementType());
+  auto c0 = b.create<arith::ConstantOp>(loc, zero);
+  return b.create<linalg::FillOp>(loc, ValueRange{c0}, ValueRange{t})
+      .getResult(0);
+}
+
 struct SparseBatchedPackCallRewriter {
   LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
     assert(op.getResults().size() == 1 && "Must be packing into one tensor");
@@ -81,118 +90,6 @@
   }
 };
 
-struct SparseUnpackCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getResults().size() + 1 == op.getInputs().size());
-    // Both jax.BCSR and jax.BCOO has three memref fields.
-    SmallVector<Type, 3> unpack_ret_tp(op.getResults().getTypes());
-    Value tensor = op.getInputs()[0];
-    Value out_vals = op.getInputs()[1];
-    ValueRange out_lvls = op.getInputs().drop_front(2);
-    // Constructs the UnpackOp.
-    auto unpack_op = rewriter.create<sparse_tensor::UnpackOp>(
-        op.getLoc(), unpack_ret_tp, tensor, out_vals, out_lvls);
-    rewriter.replaceOp(op, unpack_op.getResults());
-    return success();
-  }
-};
-
-struct SparseTransposeCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 2 && "Need argument and permutation");
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-
-    // The permutation is passed in as a constant of dense int elements.
-    auto permutation_constant =
-        op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
-    auto permutation =
-        permutation_constant.getValue().cast<DenseIntElementsAttr>();
-
-    // Reconstruct the transpose operation.
-    Value ret_sp_tensor = op.getResults()[0];
-    rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
-        op, ret_sp_tensor.getType(), op.getInputs()[0], permutation);
-    return success();
-  }
-};
-
-struct SparseDotCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 6 && "Need arguments and metadata");
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-    SmallVector<int64_t> lhs_contr, rhs_contr, lhs_batch, rhs_batch;
-    getIntegersFromDenseElements(op.getInputs()[2], lhs_contr);
-    getIntegersFromDenseElements(op.getInputs()[3], rhs_contr);
-    getIntegersFromDenseElements(op.getInputs()[4], lhs_batch);
-    getIntegersFromDenseElements(op.getInputs()[5], rhs_batch);
-    auto dot_dims = mlir::mhlo::DotDimensionNumbersAttr::get(
-        op.getContext(), lhs_batch, rhs_batch, lhs_contr, rhs_contr);
-    Value ret_sp_tensor = op.getResults()[0];
-    rewriter.replaceOpWithNewOp<mhlo::DotGeneralOp>(
-        op, ret_sp_tensor.getType(), op.getInputs()[0], op.getInputs()[1],
-        dot_dims, /*defaultPrecision*/ ArrayAttr());
-    return success();
-  }
-};
-
-struct SparseConcatenateCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-
-    // The concatenation dimension.
-    auto concat_dim = op.getInputs().back().getDefiningOp<mhlo::ConstantOp>();
-    auto concat_dim_attr = concat_dim.getValue().cast<DenseIntElementsAttr>();
-    // Reconstruct the concatenate operation.
-    Value ret_sp_tensor = op.getResults()[0];
-    // Depending on test setup, we can get either a 32-bit integer or a 64-bit
-    // integer.
-    if (concat_dim_attr.getElementType().isInteger(32)) {
-      rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
-          op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
-          rewriter.getIndexAttr(concat_dim_attr.getValues<uint32_t>()[0]));
-    } else {
-      assert(concat_dim_attr.getElementType().isInteger(64));
-      rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
-          op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
-          rewriter.getIndexAttr(concat_dim_attr.getValues<uint64_t>()[0]));
-    }
-
-    return success();
-  }
-};
-
-struct SparseBroadcastInDimCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 2 &&
-           "Need argument and broadcast dimensions");
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-
-    // Broadcast dimensions are passed in as a constant of dense int elements.
-    auto dims_constant = op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
-    auto broadcast_dimensions =
-        dims_constant.getValue().cast<DenseIntElementsAttr>();
-
-    // Reconstruct the broadcast_in_dim operation.
-    Value ret_sp_tensor = op.getResults()[0];
-    rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
-        op, ret_sp_tensor.getType(), op.getInputs()[0], broadcast_dimensions);
-    return success();
-  }
-};
-
-template <typename unaryChlo>
-struct SparseUnaryChloCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 1 && "Need one argument");
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-    // Reconstruct the unary chlo operation.
-    Value ret_sp_tensor = op.getResults()[0];
-    rewriter.replaceOpWithNewOp<unaryChlo>(op, ret_sp_tensor.getType(),
-                                           op.getInputs()[0]);
-    return success();
-  }
-};
-
 template <typename BinaryMhlo>
 struct SparseBinaryCallRewriter {
   LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
@@ -206,52 +103,86 @@
   }
 };
 
-struct SparseSliceCallRewriter {
+struct SparseBroadcastInDimCallRewriter {
   LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 4 &&
-           "Need one operand and three slicing parameters");
+    assert(op.getInputs().size() == 2 &&
+           "Need argument and broadcast dimensions");
     assert(op.getResults().size() == 1 && "Need one output tensor");
+    // Broadcast dimensions are passed in as a constant of dense int elements.
+    auto dims_constant = op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
+    auto broadcast_dimensions =
+        dims_constant.getValue().cast<DenseIntElementsAttr>();
+    // Reconstruct the broadcast_in_dim operation.
+    Value ret_sp_tensor = op.getResults()[0];
+    rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
+        op, ret_sp_tensor.getType(), op.getInputs()[0], broadcast_dimensions);
+    return success();
+  }
+};
 
-    auto ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto retTp = op.getResults().getTypes()[0].cast<RankedTensorType>();
-
-    auto offsets = getDenseIntAttrFromConstant(op.getInputs()[1]);
-    auto strides = getDenseIntAttrFromConstant(op.getInputs()[3]);
-
-    assert(offsets.getNumElements() == strides.getNumElements() &&
-           offsets.getNumElements() == retTp.getRank());
-
-    SmallVector<sparse_tensor::SparseTensorDimSliceAttr> slice_attrs;
-    SmallVector<int64_t> static_offsets, static_sizes, static_strides;
-    for (auto [offset, size, stride] :
-         llvm::zip(offsets, retTp.getShape(), strides)) {
-      int64_t o = offset.getZExtValue(), s = stride.getZExtValue();
-      // Converts limits to sizes.
-      slice_attrs.push_back(
-          sparse_tensor::SparseTensorDimSliceAttr::get(ctx, o, size, s));
-      static_offsets.push_back(o);
-      static_sizes.push_back(size);
-      static_strides.push_back(s);
+struct SparseConcatenateCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    // The concatenation dimension.
+    auto concat_dim = op.getInputs().back().getDefiningOp<mhlo::ConstantOp>();
+    auto concat_dim_attr = concat_dim.getValue().cast<DenseIntElementsAttr>();
+    // Reconstruct the concatenate operation.
+    Value ret_sp_tensor = op.getResults()[0];
+    // Depending on test setup, we can get either a 32-bit integer or a 64-bit
+    // integer.
+    if (concat_dim_attr.getElementType().isInteger(32)) {
+      rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
+          op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
+          rewriter.getIndexAttr(concat_dim_attr.getValues<uint32_t>()[0]));
+    } else {
+      assert(concat_dim_attr.getElementType().isInteger(64));
+      rewriter.replaceOpWithNewOp<sparse_tensor::ConcatenateOp>(
+          op, ret_sp_tensor.getType(), op.getInputs().drop_back(),
+          rewriter.getIndexAttr(concat_dim_attr.getValues<uint64_t>()[0]));
     }
+    return success();
+  }
+};
 
-    auto srcEnc =
-        retTp.getEncoding().cast<sparse_tensor::SparseTensorEncodingAttr>();
-    // TODO(peiming): add a getSliceEncodingFrom into MLIR upstream.
-    auto sliceEnc = sparse_tensor::SparseTensorEncodingAttr::get(
-        ctx, srcEnc.getLvlTypes(), srcEnc.getDimToLvl(), srcEnc.getPosWidth(),
-        srcEnc.getCrdWidth(), slice_attrs);
-    auto sliceTp = RankedTensorType::get(retTp.getShape(),
-                                         retTp.getElementType(), sliceEnc);
+struct SparseConvCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getInputs().size() == 2 && "Need two input tensors");
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    auto rtp = op.getResults()[0].getType().cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(
+        op, op.getInputs(), getEmptyTensor(rewriter, op.getLoc(), rtp));
+    return success();
+  }
+};
 
-    auto slice = rewriter.create<tensor::ExtractSliceOp>(
-        loc, sliceTp, op.getInputs()[0], ValueRange(), ValueRange(),
-        ValueRange(), static_offsets, static_sizes, static_strides);
+struct SparseConvertCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getInputs().size() == 1 && "Need one input tensor");
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    Value ret_sp_tensor = op.getResults()[0];
+    rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
+        op, ret_sp_tensor.getType(), op.getInputs()[0]);
+    return success();
+  }
+};
 
-    // TODO(peiming): This weakens the performance benefit we get from the
-    // sparse compiler by forcing every slice to be materizalized while the
-    // sparse compiler supports view-based slice.
-    rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(op, retTp, slice);
+struct SparseDotCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getInputs().size() == 6 && "Need arguments and metadata");
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    SmallVector<int64_t> lhs_contr, rhs_contr, lhs_batch, rhs_batch;
+    getIntegersFromDenseElements(op.getInputs()[2], lhs_contr);
+    getIntegersFromDenseElements(op.getInputs()[3], rhs_contr);
+    getIntegersFromDenseElements(op.getInputs()[4], lhs_batch);
+    getIntegersFromDenseElements(op.getInputs()[5], rhs_batch);
+    auto dot_dims = mlir::mhlo::DotDimensionNumbersAttr::get(
+        op.getContext(), lhs_batch, rhs_batch, lhs_contr, rhs_contr);
+    Value ret_sp_tensor = op.getResults()[0];
+    rewriter.replaceOpWithNewOp<mhlo::DotGeneralOp>(op, ret_sp_tensor.getType(),
+                                                    op.getInputs()[0],
+                                                    op.getInputs()[1], dot_dims,
+                                                    /*defaultPrecision*/
+                                                    ArrayAttr());
     return success();
   }
 };
@@ -259,7 +190,6 @@
 struct SparseDynSliceCallRewriter {
   LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
     assert(op.getResults().size() == 1 && "Need one output tensor");
-
     auto ctx = op.getContext();
     auto loc = op.getLoc();
     auto retTp = op.getResults().getTypes()[0].cast<RankedTensorType>();
@@ -312,54 +242,6 @@
   }
 };
 
-struct SparseReshapeCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 1 && "Need one input tensor");
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-
-    // Reconstruct the reshape operation.
-    Value ret_sp_tensor = op.getResults()[0];
-    // TODO(anlunx): Fix the issue that the reshape is rewritten to a collapse +
-    // expand pair where the sparsity encoding is dropped in between.
-    rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, ret_sp_tensor.getType(),
-                                                 op.getInputs()[0]);
-    return success();
-  }
-};
-
-static Value getEmptyTensor(OpBuilder& b, Location loc, RankedTensorType type) {
-  auto t = b.create<tensor::EmptyOp>(loc, type.getShape(),
-                                     type.getElementType(), ValueRange{});
-  auto zero = b.getZeroAttr(type.getElementType());
-  auto c0 = b.create<arith::ConstantOp>(loc, zero);
-  return b.create<linalg::FillOp>(loc, ValueRange{c0}, ValueRange{t})
-      .getResult(0);
-}
-
-struct SparseConvCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 2 && "Need two input tensors");
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-
-    auto rtp = op.getResults()[0].getType().cast<RankedTensorType>();
-
-    rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(
-        op, op.getInputs(), getEmptyTensor(rewriter, op.getLoc(), rtp));
-    return success();
-  }
-};
-
-struct SparseConvertCallRewriter {
-  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
-    assert(op.getInputs().size() == 1 && "Need one input tensor");
-    assert(op.getResults().size() == 1 && "Need one output tensor");
-    Value ret_sp_tensor = op.getResults()[0];
-    rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
-        op, ret_sp_tensor.getType(), op.getInputs()[0]);
-    return success();
-  }
-};
-
 struct SparseReduceSumCallRewriter {
   LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
     assert(op.getInputs().size() == 2 && "Need one input tensor and axes");
@@ -394,27 +276,122 @@
           rewriter.create<mhlo::AddOp>(loc, *firstArgument, *secondArgument);
       rewriter.create<mhlo::ReturnOp>(loc, addResult);
     }
-
     rewriter.replaceOp(op, reduce.getResults());
     return success();
   }
 };
 
+struct SparseReshapeCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getInputs().size() == 1 && "Need one input tensor");
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    // Reconstruct the reshape operation.
+    Value ret_sp_tensor = op.getResults()[0];
+    // TODO(anlunx): Fix the issue that the reshape is rewritten to a collapse +
+    // expand pair where the sparsity encoding is dropped in between.
+    rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, ret_sp_tensor.getType(),
+                                                 op.getInputs()[0]);
+    return success();
+  }
+};
+
+struct SparseSliceCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getInputs().size() == 4 &&
+           "Need one operand and three slicing parameters");
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    auto ctx = op.getContext();
+    auto loc = op.getLoc();
+    auto retTp = op.getResults().getTypes()[0].cast<RankedTensorType>();
+    auto offsets = getDenseIntAttrFromConstant(op.getInputs()[1]);
+    auto strides = getDenseIntAttrFromConstant(op.getInputs()[3]);
+    assert(offsets.getNumElements() == strides.getNumElements() &&
+           offsets.getNumElements() == retTp.getRank());
+    SmallVector<sparse_tensor::SparseTensorDimSliceAttr> slice_attrs;
+    SmallVector<int64_t> static_offsets, static_sizes, static_strides;
+    for (auto [offset, size, stride] :
+         llvm::zip(offsets, retTp.getShape(), strides)) {
+      int64_t o = offset.getZExtValue(), s = stride.getZExtValue();
+      // Converts limits to sizes.
+      slice_attrs.push_back(
+          sparse_tensor::SparseTensorDimSliceAttr::get(ctx, o, size, s));
+      static_offsets.push_back(o);
+      static_sizes.push_back(size);
+      static_strides.push_back(s);
+    }
+    auto srcEnc =
+        retTp.getEncoding().cast<sparse_tensor::SparseTensorEncodingAttr>();
+    // TODO(peiming): add a getSliceEncodingFrom into MLIR upstream.
+    auto sliceEnc = sparse_tensor::SparseTensorEncodingAttr::get(
+        ctx, srcEnc.getLvlTypes(), srcEnc.getDimToLvl(), srcEnc.getPosWidth(),
+        srcEnc.getCrdWidth(), slice_attrs);
+    auto sliceTp = RankedTensorType::get(retTp.getShape(),
+                                         retTp.getElementType(), sliceEnc);
+    auto slice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, sliceTp, op.getInputs()[0], ValueRange(), ValueRange(),
+        ValueRange(), static_offsets, static_sizes, static_strides);
+    // TODO(peiming): This weakens the performance benefit we get from the
+    // sparse compiler by forcing every slice to be materialized, even though
+    // the sparse compiler supports view-based slices.
+    rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(op, retTp, slice);
+    return success();
+  }
+};
+
+struct SparseTransposeCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getInputs().size() == 2 && "Need argument and permutation");
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    // The permutation is passed in as a constant of dense int elements.
+    auto permutation_constant =
+        op.getInputs()[1].getDefiningOp<mhlo::ConstantOp>();
+    auto permutation =
+        permutation_constant.getValue().cast<DenseIntElementsAttr>();
+    // Reconstruct the transpose operation.
+    Value ret_sp_tensor = op.getResults()[0];
+    rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
+        op, ret_sp_tensor.getType(), op.getInputs()[0], permutation);
+    return success();
+  }
+};
+
+template <typename unaryChlo>
+struct SparseUnaryChloCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getInputs().size() == 1 && "Need one argument");
+    assert(op.getResults().size() == 1 && "Need one output tensor");
+    // Reconstruct the unary chlo operation.
+    Value ret_sp_tensor = op.getResults()[0];
+    rewriter.replaceOpWithNewOp<unaryChlo>(op, ret_sp_tensor.getType(),
+                                           op.getInputs()[0]);
+    return success();
+  }
+};
+
+struct SparseUnpackCallRewriter {
+  LogicalResult operator()(mhlo::CustomCallOp op, PatternRewriter& rewriter) {
+    assert(op.getResults().size() + 1 == op.getInputs().size());
+    // Both jax.BCSR and jax.BCOO have three memref fields.
+    SmallVector<Type, 3> unpack_ret_tp(op.getResults().getTypes());
+    Value tensor = op.getInputs()[0];
+    Value out_vals = op.getInputs()[1];
+    ValueRange out_lvls = op.getInputs().drop_front(2);
+    // Constructs the UnpackOp.
+    auto unpack_op = rewriter.create<sparse_tensor::UnpackOp>(
+        op.getLoc(), unpack_ret_tp, tensor, out_vals, out_lvls);
+    rewriter.replaceOp(op, unpack_op.getResults());
+    return success();
+  }
+};
+
 class SparseCustomCallRewriter : public OpRewritePattern<mhlo::CustomCallOp> {
   using OpRewritePattern<mhlo::CustomCallOp>::OpRewritePattern;
   using SparseCustomTargetRewriter = std::function<LogicalResult(
       mhlo::CustomCallOp op, PatternRewriter& rewriter)>;
 
   const llvm::StringMap<SparseCustomTargetRewriter> rewriter_map_{
-      std::make_pair("sparse_tensor_sparse_pack",
-                     SparseBatchedPackCallRewriter()),
-      std::make_pair("sparse_tensor_sparse_unpack", SparseUnpackCallRewriter()),
-      std::make_pair("sparse_tensor_transpose", SparseTransposeCallRewriter()),
-      std::make_pair("sparse_tensor_dot_general", SparseDotCallRewriter()),
-      std::make_pair("sparse_tensor_concatenate",
-                     SparseConcatenateCallRewriter()),
-      std::make_pair("sparse_tensor_broadcast_in_dim",
-                     SparseBroadcastInDimCallRewriter()),
+      std::make_pair("sparse_tensor_add",
+                     SparseBinaryCallRewriter<mhlo::AddOp>()),
       std::make_pair("sparse_tensor_asin",
                      SparseUnaryChloCallRewriter<chlo::AsinOp>()),
       std::make_pair("sparse_tensor_asinh",
@@ -425,24 +402,31 @@
                      SparseUnaryChloCallRewriter<chlo::AtanhOp>()),
       std::make_pair("sparse_tensor_bessel_i1e",
                      SparseUnaryChloCallRewriter<chlo::BesselI1eOp>()),
-      std::make_pair("sparse_tensor_sinh",
-                     SparseUnaryChloCallRewriter<chlo::SinhOp>()),
-      std::make_pair("sparse_tensor_tan",
-                     SparseUnaryChloCallRewriter<chlo::TanOp>()),
-      std::make_pair("sparse_tensor_slice", SparseSliceCallRewriter()),
-      std::make_pair("sparse_tensor_dynamic_slice",
-                     SparseDynSliceCallRewriter()),
-      std::make_pair("sparse_tensor_reshape", SparseReshapeCallRewriter()),
-      std::make_pair("sparse_tensor_reduce_sum", SparseReduceSumCallRewriter()),
+      std::make_pair("sparse_tensor_broadcast_in_dim",
+                     SparseBroadcastInDimCallRewriter()),
+      std::make_pair("sparse_tensor_concatenate",
+                     SparseConcatenateCallRewriter()),
       std::make_pair("sparse_tensor_conv_general_dilated",
                      SparseConvCallRewriter()),
       std::make_pair("sparse_tensor_convert", SparseConvertCallRewriter()),
-      std::make_pair("sparse_tensor_add",
-                     SparseBinaryCallRewriter<mhlo::AddOp>()),
-      std::make_pair("sparse_tensor_sub",
-                     SparseBinaryCallRewriter<mhlo::SubtractOp>()),
+      std::make_pair("sparse_tensor_dot_general", SparseDotCallRewriter()),
+      std::make_pair("sparse_tensor_dynamic_slice",
+                     SparseDynSliceCallRewriter()),
       std::make_pair("sparse_tensor_mul",
                      SparseBinaryCallRewriter<mhlo::MulOp>()),
+      std::make_pair("sparse_tensor_reduce_sum", SparseReduceSumCallRewriter()),
+      std::make_pair("sparse_tensor_reshape", SparseReshapeCallRewriter()),
+      std::make_pair("sparse_tensor_sinh",
+                     SparseUnaryChloCallRewriter<chlo::SinhOp>()),
+      std::make_pair("sparse_tensor_slice", SparseSliceCallRewriter()),
+      std::make_pair("sparse_tensor_sparse_pack",
+                     SparseBatchedPackCallRewriter()),
+      std::make_pair("sparse_tensor_sparse_unpack", SparseUnpackCallRewriter()),
+      std::make_pair("sparse_tensor_sub",
+                     SparseBinaryCallRewriter<mhlo::SubtractOp>()),
+      std::make_pair("sparse_tensor_tan",
+                     SparseUnaryChloCallRewriter<chlo::TanOp>()),
+      std::make_pair("sparse_tensor_transpose", SparseTransposeCallRewriter()),
   };
 
   // Rewrites a CustomCallOp to corresponding sparse_tensor operation.