Add conversion for tosa equal (#336)
Co-authored-by: Marius Brehler <marius.brehler@iml.fraunhofer.de>
diff --git a/docs/tosa-op-coverage.md b/docs/tosa-op-coverage.md
index 4f4e822..b186f9c 100644
--- a/docs/tosa-op-coverage.md
+++ b/docs/tosa-op-coverage.md
@@ -23,6 +23,7 @@
| **Binary elementwise ops**
| add | :heavy_check_mark: | |
| arithmetic_right_shift | :heavy_check_mark: | |
+| equal | :heavy_check_mark: | |
| logical_left_shift | :heavy_check_mark: | |
| maximum | :heavy_check_mark: | |
| minimum | :heavy_check_mark: | |
diff --git a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
index 8070e1d..3e3871e 100644
--- a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
+++ b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
@@ -768,6 +768,8 @@
"emitc::tosa::add");
patterns.add<ArithmeticRightShiftOpConversion>(
ctx, "emitc::tosa::arithmetic_right_shift");
+ patterns.add<CallOpBroadcastableConversion<tosa::EqualOp>>(
+ ctx, "emitc::tosa::equal", /*explicitResultType=*/true);
patterns.add<CallOpBroadcastableConversion<tosa::LogicalLeftShiftOp>>(
ctx, "emitc::tosa::logical_left_shift");
patterns.add<CallOpBroadcastableConversion<tosa::MaximumOp>>(
@@ -844,6 +846,7 @@
// Binary elementwise ops.
target.addIllegalOp<tosa::AddOp,
tosa::ArithmeticRightShiftOp,
+ tosa::EqualOp,
tosa::LogicalLeftShiftOp,
tosa::MaximumOp,
tosa::MinimumOp,
diff --git a/reference-implementation/include/emitc/tosa.h b/reference-implementation/include/emitc/tosa.h
index 8a1b46f..30610a2 100644
--- a/reference-implementation/include/emitc/tosa.h
+++ b/reference-implementation/include/emitc/tosa.h
@@ -181,6 +181,14 @@
return binary<Src>(x, y, f);
}
+// EqualOp
+template <typename Dest, typename Src>
+inline Dest equal(Src x, Src y) {
+ using ET_Src = typename get_element_type<Src>::type;
+ auto f = [](ET_Src left, ET_Src right) { return left == right; };
+ return binary<Dest, Src>(x, y, f);
+}
+
// LogicalLeftShiftOp
template <typename Src>
inline Src logical_left_shift(Src x, Src y) {
diff --git a/reference-implementation/unittests/tosa.cpp b/reference-implementation/unittests/tosa.cpp
index fae491c..cf87e11 100644
--- a/reference-implementation/unittests/tosa.cpp
+++ b/reference-implementation/unittests/tosa.cpp
@@ -319,6 +319,14 @@
}
}
+TEST(tosa, equal) {
+ Tensor2D<int32_t, 2, 2> s0{3, 2, -1, 0};
+ Tensor2D<int32_t, 2, 2> t0{3, -2, 0, 0};
+ Tensor2D<bool, 2, 2> expected_result{true, false, false, true};
+ Tensor2D<bool, 2, 2> result = tosa::equal<Tensor<bool, 2, 2>>(s0, t0);
+ EXPECT_THAT(result, Pointwise(Eq(), expected_result));
+}
+
TEST(tosa, logical_left_shift) {
Tensor1D<int16_t, 4> s0{0b1, 0b1, -0b1, 0b101};
Tensor1D<int16_t, 4> t0{0, 1, 1, 2};
diff --git a/test/Conversion/tosa-to-emitc.mlir b/test/Conversion/tosa-to-emitc.mlir
index 0a9e2d4..ac778de 100644
--- a/test/Conversion/tosa-to-emitc.mlir
+++ b/test/Conversion/tosa-to-emitc.mlir
@@ -125,6 +125,12 @@
return %0 : tensor<13x21x3xi32>
}
+func.func @test_equal(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi1> {
+ // CHECK: emitc.call "emitc::tosa::equal"(%arg0, %arg1) {template_args = [tensor<13x21x3xi1>]} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1>
+ %0 = "tosa.equal"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
func.func @test_logical_left_shift(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// CHECK: emitc.call "emitc::broadcast_in_dim"(%arg0) {args = [0 : index, dense<[0, 1, 2]> : tensor<3xi64>], template_args = [tensor<13x21x3xi32>]} : (tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
// CHECK: emitc.call "emitc::tosa::logical_left_shift"(%0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>