Add conversion for tosa equal (#336)

Co-authored-by: Marius Brehler <marius.brehler@iml.fraunhofer.de>
diff --git a/docs/tosa-op-coverage.md b/docs/tosa-op-coverage.md
index 4f4e822..b186f9c 100644
--- a/docs/tosa-op-coverage.md
+++ b/docs/tosa-op-coverage.md
@@ -23,6 +23,7 @@
 | **Binary elementwise ops**
 | add                    | :heavy_check_mark: | |
 | arithmetic_right_shift | :heavy_check_mark: | |
+| equal                  | :heavy_check_mark: | |
 | logical_left_shift     | :heavy_check_mark: | |
 | maximum                | :heavy_check_mark: | |
 | minimum                | :heavy_check_mark: | |
diff --git a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
index 8070e1d..3e3871e 100644
--- a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
+++ b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
@@ -768,6 +768,8 @@
                                                            "emitc::tosa::add");
   patterns.add<ArithmeticRightShiftOpConversion>(
       ctx, "emitc::tosa::arithmetic_right_shift");
+  patterns.add<CallOpBroadcastableConversion<tosa::EqualOp>>(
+      ctx, "emitc::tosa::equal", /*explicitResultType=*/true);
   patterns.add<CallOpBroadcastableConversion<tosa::LogicalLeftShiftOp>>(
       ctx, "emitc::tosa::logical_left_shift");
   patterns.add<CallOpBroadcastableConversion<tosa::MaximumOp>>(
@@ -844,6 +846,7 @@
     // Binary elementwise ops.
     target.addIllegalOp<tosa::AddOp,
                         tosa::ArithmeticRightShiftOp,
+                        tosa::EqualOp,
                         tosa::LogicalLeftShiftOp,
                         tosa::MaximumOp,
                         tosa::MinimumOp,
diff --git a/reference-implementation/include/emitc/tosa.h b/reference-implementation/include/emitc/tosa.h
index 8a1b46f..30610a2 100644
--- a/reference-implementation/include/emitc/tosa.h
+++ b/reference-implementation/include/emitc/tosa.h
@@ -181,6 +181,14 @@
   return binary<Src>(x, y, f);
 }
 
+// EqualOp
+template <typename Dest, typename Src>
+inline Dest equal(Src x, Src y) {
+  using ET_Src = typename get_element_type<Src>::type;
+  auto f = [](ET_Src left, ET_Src right) { return left == right; };
+  return binary<Dest, Src>(x, y, f);
+}
+
 // LogicalLeftShiftOp
 template <typename Src>
 inline Src logical_left_shift(Src x, Src y) {
diff --git a/reference-implementation/unittests/tosa.cpp b/reference-implementation/unittests/tosa.cpp
index fae491c..cf87e11 100644
--- a/reference-implementation/unittests/tosa.cpp
+++ b/reference-implementation/unittests/tosa.cpp
@@ -319,6 +319,14 @@
   }
 }
 
+TEST(tosa, equal) {
+  Tensor2D<int32_t, 2, 2> s0{3, 2, -1, 0};
+  Tensor2D<int32_t, 2, 2> t0{3, -2, 0, 0};
+  Tensor2D<bool, 2, 2> expected_result{true, false, false, true};
+  Tensor2D<bool, 2, 2> result = tosa::equal<Tensor<bool, 2, 2>>(s0, t0);
+  EXPECT_THAT(result, Pointwise(Eq(), expected_result));
+}
+
 TEST(tosa, logical_left_shift) {
   Tensor1D<int16_t, 4> s0{0b1, 0b1, -0b1, 0b101};
   Tensor1D<int16_t, 4> t0{0, 1, 1, 2};
diff --git a/test/Conversion/tosa-to-emitc.mlir b/test/Conversion/tosa-to-emitc.mlir
index 0a9e2d4..ac778de 100644
--- a/test/Conversion/tosa-to-emitc.mlir
+++ b/test/Conversion/tosa-to-emitc.mlir
@@ -125,6 +125,12 @@
   return %0 : tensor<13x21x3xi32>
 }
 
+func.func @test_equal(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi1> {
+  // CHECK: emitc.call "emitc::tosa::equal"(%arg0, %arg1) {template_args = [tensor<13x21x3xi1>]} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1>
+  %0 = "tosa.equal"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
 func.func @test_logical_left_shift(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
   // CHECK: emitc.call "emitc::broadcast_in_dim"(%arg0) {args = [0 : index, dense<[0, 1, 2]> : tensor<3xi64>], template_args = [tensor<13x21x3xi32>]} : (tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
   // CHECK: emitc.call "emitc::tosa::logical_left_shift"(%0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>