Add tosa.greater_equal lowering to EmitC (#341)
Co-authored-by: Marius Brehler <marius.brehler@iml.fraunhofer.de>
diff --git a/docs/tosa-op-coverage.md b/docs/tosa-op-coverage.md
index d8a903f..f4d7411 100644
--- a/docs/tosa-op-coverage.md
+++ b/docs/tosa-op-coverage.md
@@ -24,6 +24,7 @@
| add | :heavy_check_mark: | |
| arithmetic_right_shift | :heavy_check_mark: | |
| equal | :heavy_check_mark: | |
+| greater_equal | :heavy_check_mark: | |
| logical_left_shift | :heavy_check_mark: | |
| maximum | :heavy_check_mark: | |
| minimum | :heavy_check_mark: | |
diff --git a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
index 6bf0a0e..a169aca 100644
--- a/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
+++ b/lib/Conversion/TosaToEmitC/TosaToEmitC.cpp
@@ -883,6 +883,8 @@
ctx, "emitc::tosa::arithmetic_right_shift");
patterns.add<CallOpBroadcastableConversion<tosa::EqualOp>>(
ctx, "emitc::tosa::equal", /*explicitResultType=*/true);
+ patterns.add<CallOpBroadcastableConversion<tosa::GreaterEqualOp>>(
+ ctx, "emitc::tosa::greater_equal", /*explicitResultType=*/true);
patterns.add<CallOpBroadcastableConversion<tosa::LogicalLeftShiftOp>>(
ctx, "emitc::tosa::logical_left_shift");
patterns.add<CallOpBroadcastableConversion<tosa::MaximumOp>>(
@@ -967,6 +969,7 @@
target.addIllegalOp<tosa::AddOp,
tosa::ArithmeticRightShiftOp,
tosa::EqualOp,
+ tosa::GreaterEqualOp,
tosa::LogicalLeftShiftOp,
tosa::MaximumOp,
tosa::MinimumOp,
diff --git a/reference-implementation/include/emitc/tosa.h b/reference-implementation/include/emitc/tosa.h
index b81ae3b..368dda1 100644
--- a/reference-implementation/include/emitc/tosa.h
+++ b/reference-implementation/include/emitc/tosa.h
@@ -189,6 +189,14 @@
return binary<Dest, Src>(x, y, f);
}
+// GreaterEqualOp
+template <typename Dest, typename Src>
+inline Dest greater_equal(Src x, Src y) {
+ using ET_Src = typename get_element_type<Src>::type;
+ auto f = [](ET_Src left, ET_Src right) { return left >= right; };
+ return binary<Dest, Src>(x, y, f);
+}
+
// LogicalLeftShiftOp
template <typename Src>
inline Src logical_left_shift(Src x, Src y) {
diff --git a/reference-implementation/unittests/tosa.cpp b/reference-implementation/unittests/tosa.cpp
index 6a26e63..4c00f98 100644
--- a/reference-implementation/unittests/tosa.cpp
+++ b/reference-implementation/unittests/tosa.cpp
@@ -327,6 +327,15 @@
EXPECT_THAT(result, Pointwise(Eq(), expected_result));
}
+TEST(tosa, greater_equal) {
+ Tensor3D<int32_t, 2, 2, 1> s0{3, 2, -1, -5};
+ Tensor3D<int32_t, 2, 2, 1> t0{3, -2, 0, -3};
+ Tensor3D<bool, 2, 2, 1> expected_result{true, true, false, false};
+ Tensor3D<bool, 2, 2, 1> result =
+ tosa::greater_equal<Tensor<bool, 2, 2, 1>>(s0, t0);
+ EXPECT_THAT(result, Pointwise(Eq(), expected_result));
+}
+
TEST(tosa, logical_left_shift) {
Tensor1D<int16_t, 4> s0{0b1, 0b1, -0b1, 0b101};
Tensor1D<int16_t, 4> t0{0, 1, 1, 2};
diff --git a/test/Conversion/tosa-to-emitc.mlir b/test/Conversion/tosa-to-emitc.mlir
index 2637162..c037290 100644
--- a/test/Conversion/tosa-to-emitc.mlir
+++ b/test/Conversion/tosa-to-emitc.mlir
@@ -131,6 +131,12 @@
return %0 : tensor<13x21x3xi1>
}
+func.func @test_greater_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ // CHECK: emitc.call "emitc::tosa::greater_equal"(%arg0, %arg1) {template_args = [tensor<13x21x3xi1>]} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
func.func @test_logical_left_shift(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// CHECK: emitc.call "emitc::broadcast_in_dim"(%arg0) {args = [0 : index, dense<[0, 1, 2]> : tensor<3xi64>], template_args = [tensor<13x21x3xi32>]} : (tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
// CHECK: emitc.call "emitc::tosa::logical_left_shift"(%0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>