// blob: 1d70d9b31b1b94131e5802fb09464d753c8879ee [file] [log] [blame]
// RUN: stablehlo-opt %s --tosa-legalize-stablehlo | FileCheck %s
// Elementwise absolute value on a 10-element f32 tensor; the legalized
// function body is expected to contain a tosa.abs op.
// CHECK-LABEL: @abs
func.func @abs(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: tosa.abs
  %result = "stablehlo.abs"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// Elementwise ceiling on a 10-element f32 tensor; the legalized function
// body is expected to contain a tosa.ceil op.
// CHECK-LABEL: @ceil
func.func @ceil(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: tosa.ceil
  %result = "stablehlo.ceil"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// Element type conversion from i32 to f32 on a 10-element tensor; the
// legalized function body is expected to contain a tosa.cast op.
// CHECK-LABEL: @convert
func.func @convert(%operand : tensor<10xi32>) -> tensor<10xf32> {
  // CHECK: tosa.cast
  %result = "stablehlo.convert"(%operand) : (tensor<10xi32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// Elementwise exponential on a 10-element f32 tensor; the legalized
// function body is expected to contain a tosa.exp op.
// CHECK-LABEL: @exponential
func.func @exponential(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: tosa.exp
  %result = "stablehlo.exponential"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// exp(x) - 1 on a 10-element f32 tensor. The expected output is a
// three-op sequence: a constant 1.0, tosa.exp applied to the function
// argument, and tosa.sub of the exp result and the constant.
// CHECK-LABEL: @exponential_minus_one
func.func @exponential_minus_one(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00>
  // CHECK-DAG: %[[VAR1:.*]] = "tosa.exp"(%arg0)
  // CHECK-DAG: %[[VAR2:.*]] = "tosa.sub"(%[[VAR1]], %[[VAR0]])
  %result = "stablehlo.exponential_minus_one"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// Elementwise floor on a 10-element f32 tensor; the legalized function
// body is expected to contain a tosa.floor op.
// CHECK-LABEL: @floor
func.func @floor(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: tosa.floor
  %result = "stablehlo.floor"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// Finiteness predicate on a 10-element f32 tensor, producing i1 values.
// The expected output compares tosa.abs of the argument for equality with
// a constant whose bit pattern is 0x7F800000 (f32 +infinity) and then
// inverts that comparison with tosa.logical_not.
// CHECK-LABEL: @is_finite
func.func @is_finite(%operand : tensor<10xf32>) -> tensor<10xi1> {
  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0x7F800000>
  // CHECK-DAG: %[[VAR1:.*]] = "tosa.abs"(%arg0)
  // CHECK-DAG: %[[VAR2:.*]] = "tosa.equal"(%[[VAR1]], %[[VAR0]])
  // CHECK-DAG: %[[VAR3:.*]] = "tosa.logical_not"(%[[VAR2]])
  %finite = "stablehlo.is_finite"(%operand) : (tensor<10xf32>) -> tensor<10xi1>
  return %finite : tensor<10xi1>
}
// Elementwise logarithm on a 10-element f32 tensor; the legalized function
// body is expected to contain a tosa.log op.
// CHECK-LABEL: @log
func.func @log(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: tosa.log
  %result = "stablehlo.log"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// log(x + 1) on a 10-element f16 tensor. The expected output is a
// three-op sequence: a constant 1.0, tosa.add of the function argument and
// the constant, and tosa.log of that sum.
// CHECK-LABEL: @log_plus_one
func.func @log_plus_one(%operand : tensor<10xf16>) -> tensor<10xf16> {
  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00>
  // CHECK-DAG: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
  // CHECK-DAG: %[[VAR2:.*]] = "tosa.log"(%[[VAR1]])
  %result = "stablehlo.log_plus_one"(%operand) : (tensor<10xf16>) -> tensor<10xf16>
  return %result : tensor<10xf16>
}
// Elementwise negation on a 10-element f32 tensor; the legalized function
// body is expected to contain a tosa.negate op.
// CHECK-LABEL: @negate
func.func @negate(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: tosa.negate
  %result = "stablehlo.negate"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// Unit-stride slice of a 4x3 f32 tensor from start [2, 1] to limit [4, 3].
// The expected tosa.slice keeps the start indices and carries the extent
// as size [2, 2], i.e. limit minus start in each dimension.
// CHECK-LABEL: @slice
func.func @slice(%input : tensor<4x3xf32>) -> tensor<2x2xf32> {
  // CHECK: "tosa.slice"(%arg0) {size = array<i64: 2, 2>, start = array<i64: 2, 1>}
  %window = "stablehlo.slice"(%input) {
    start_indices = dense<[2, 1]> : tensor<2xi64>,
    limit_indices = dense<[4, 3]> : tensor<2xi64>,
    strides = dense<1> : tensor<2xi64>
  } : (tensor<4x3xf32>) -> tensor<2x2xf32>
  return %window : tensor<2x2xf32>
}
// Negative case: same start/limit as the unit-stride slice test, but with a
// stride of 2 in the second dimension. The stablehlo.slice op is expected
// to remain in the output.
// CHECK-LABEL: @slice_stride_not_one
func.func @slice_stride_not_one(%input : tensor<4x3xf32>) -> tensor<2x1xf32> {
  // tosa.slice only supports strides of 1, so this should not legalize.
  // CHECK: "stablehlo.slice"
  %window = "stablehlo.slice"(%input) {
    start_indices = dense<[2, 1]> : tensor<2xi64>,
    limit_indices = dense<[4, 3]> : tensor<2xi64>,
    strides = dense<[1, 2]> : tensor<2xi64>
  } : (tensor<4x3xf32>) -> tensor<2x1xf32>
  return %window : tensor<2x1xf32>
}
// Negative case: a unit-stride slice of a rank-7 f32 tensor. The
// stablehlo.slice op is expected to remain in the output.
// CHECK-LABEL: @slice_rank_seven
func.func @slice_rank_seven(%input : tensor<2x3x4x5x6x7x8xf32>) -> tensor<1x2x3x4x5x6x7xf32> {
  // tosa.slice only supports 1D to 6D tensors, so this should not legalize.
  // CHECK: "stablehlo.slice"
  %window = "stablehlo.slice"(%input) {
    start_indices = dense<[1, 1, 1, 1, 1, 1, 1]> : tensor<7xi64>,
    limit_indices = dense<[2, 3, 4, 5, 6, 7, 8]> : tensor<7xi64>,
    strides = dense<[1, 1, 1, 1, 1, 1, 1]> : tensor<7xi64>
  } : (tensor<2x3x4x5x6x7x8xf32>) -> tensor<1x2x3x4x5x6x7xf32>
  return %window : tensor<1x2x3x4x5x6x7xf32>
}
// Elementwise hyperbolic tangent on a 10-element f32 tensor; the legalized
// function body is expected to contain a tosa.tanh op.
// CHECK-LABEL: @tanh
func.func @tanh(%operand : tensor<10xf32>) -> tensor<10xf32> {
  // CHECK: tosa.tanh
  %result = "stablehlo.tanh"(%operand) : (tensor<10xf32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}
// Transpose of a 1x2x3 f32 tensor with permutation [2, 1, 0]. The expected
// output materializes the permutation as a tensor<3xi64> constant and
// passes it as the second operand of tosa.transpose.
// CHECK-LABEL: @transpose
func.func @transpose(%input: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
  // CHECK-DAG: %[[VAR1:.*]] = "tosa.transpose"(%arg0, %[[VAR0]])
  %permuted = "stablehlo.transpose"(%input) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
  return %permuted : tensor<3x2x1xf32>
}
// Loop over a scalar i32 carried value. The first region is the condition
// (carried value == 3) and the second is the body (carried value + 1).
// The expected output is a tosa.while_loop with the same two-region shape,
// using tosa.equal / tosa.add and terminating each region with tosa.yield;
// both constants are expected to be matched ahead of the loop op.
// CHECK-LABEL: @while
func.func @while(%init: tensor<i32>) -> tensor<i32> {
  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<3> : tensor<i32>}
  // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<1> : tensor<i32>}
  // CHECK: %[[VAR2:.*]] = "tosa.while_loop"(%arg0) ({
  // CHECK: ^bb0(%[[ARG0:.+]]: tensor<i32>):
  // CHECK: %[[VAR3:.*]] = "tosa.equal"(%[[ARG0]], %[[VAR0]])
  // CHECK: "tosa.yield"(%[[VAR3]])
  // CHECK: }, {
  // CHECK: ^bb0(%[[ARG0:.+]]: tensor<i32>):
  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[ARG0]], %[[VAR1]])
  // CHECK: "tosa.yield"(%[[VAR4]])
  // CHECK: }) : (tensor<i32>) -> tensor<i32>
  // CHECK: return %[[VAR2]] : tensor<i32>
  // CHECK: }
  %final = "stablehlo.while"(%init) ({
  // Condition region: yields an i1 that is true when the carried value is 3.
  ^bb0(%carried: tensor<i32>):
    %three = "stablehlo.constant"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
    %keep_going = "stablehlo.compare"(%carried, %three) {comparison_direction = #stablehlo<comparison_direction EQ>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%keep_going) : (tensor<i1>) -> ()
  }, {
  // Body region: yields the carried value incremented by 1.
  ^bb0(%carried: tensor<i32>):
    %one = "stablehlo.constant"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
    %next = "stablehlo.add"(%carried, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%next) : (tensor<i32>) -> ()
  }) : (tensor<i32>) -> (tensor<i32>)
  return %final : tensor<i32>
}