[tosa] Add path to TOSA-backed backends. (#1466)

Connect these two industry standards by way of dialect legalization from
StableHLO to TOSA. This adds basic support and testing: combined with
some of the equivalent canonicalization patterns from MHLO (not included
in this PR), it has been sufficient to lower some full models from an ML
framework all the way to a TOSA backend. Support is not yet complete and
partly relies on canonical StableHLO forms.

The legalizations are written primarily in PDLL, though we have not yet
adopted its newer support for variadics. This work started by targeting
MHLO in the TensorFlow repo while StableHLO was still young, but given
StableHLO's development it now makes more sense to start from StableHLO
and provide a connection for community backends.
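
For illustration, most of the unary and binary legalizations in
legalize_stablehlo.pdll take this one-line form (copied verbatim from the
pattern file added below):

    Pattern =>
      replace op<stablehlo.abs>(input : Value<_: Tosa_Tensor>)
         with op<tosa.abs>(input);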

No new repository dependency is introduced, and the CMake config allows
the conversions build to be disabled.
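
The pass can be exercised directly (a sketch, assuming a stablehlo-opt
build with the conversions enabled):

    stablehlo-opt --tosa-legalize-stablehlo input.mlir

For example, stablehlo.add on tensor<10xf32> operands lowers to tosa.add
(see conversions/tosa/tests/binary.mlir).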

---------

Co-authored-by: Eugene Burmako <burmako@google.com>
diff --git a/BUILD.bazel b/BUILD.bazel
index 1ab61e9..caf3999 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+load("@bazel_skylib//rules:build_test.bzl", "build_test")
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library")
 
 package(
@@ -825,9 +826,11 @@
     deps = [
         ":register",
         ":stablehlo_passes",
+        "//stablehlo/conversions/tosa/transforms:StablehloTOSATransforms",
         "//stablehlo/tests:test_utils",
         "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:TosaDialect",
     ],
 )
 
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 25f1654..3a9d475 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -119,3 +119,4 @@
 add_custom_target(check-stablehlo)
 
 add_subdirectory(stablehlo)
+
diff --git a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt
index 87293a6..69d4196 100644
--- a/stablehlo/CMakeLists.txt
+++ b/stablehlo/CMakeLists.txt
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 add_subdirectory(api)
+add_subdirectory(conversions)
 add_subdirectory(dialect)
 add_subdirectory(integrations)
 add_subdirectory(reference)
diff --git a/stablehlo/conversions/CMakeLists.txt b/stablehlo/conversions/CMakeLists.txt
new file mode 100644
index 0000000..5b64ba3
--- /dev/null
+++ b/stablehlo/conversions/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(tosa)
diff --git a/stablehlo/conversions/tosa/CMakeLists.txt b/stablehlo/conversions/tosa/CMakeLists.txt
new file mode 100644
index 0000000..b5f42d6
--- /dev/null
+++ b/stablehlo/conversions/tosa/CMakeLists.txt
@@ -0,0 +1,22 @@
+# Copyright 2022 OpenXLA Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_subdirectory(transforms)
+add_subdirectory(tests)
+
+add_mlir_pdll_library(StablehloTOSAPDLLPatternsIncGen
+  transforms/legalize_stablehlo.pdll
+  transforms/legalize_stablehlo.pdll.h.inc
+  )
+
diff --git a/stablehlo/conversions/tosa/README.md b/stablehlo/conversions/tosa/README.md
new file mode 100644
index 0000000..797eebf
--- /dev/null
+++ b/stablehlo/conversions/tosa/README.md
@@ -0,0 +1,4 @@
+# StableHLO to TOSA legalization
+
+This module contains the [MLIR](https://mlir.llvm.org) utilities for
+legalization and interop between StableHLO and TOSA.
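+
+A minimal invocation of the legalization pass (a sketch, assuming a
+`stablehlo-opt` build with the conversions enabled):
+
+```shell
+stablehlo-opt --tosa-legalize-stablehlo input.mlir
+```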
diff --git a/stablehlo/conversions/tosa/tests/CMakeLists.txt b/stablehlo/conversions/tosa/tests/CMakeLists.txt
new file mode 100644
index 0000000..0e888c1
--- /dev/null
+++ b/stablehlo/conversions/tosa/tests/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright 2022 OpenXLA Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+configure_lit_site_cfg(
+        ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
+        ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
+        MAIN_CONFIG
+        ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
+)
+
+add_lit_testsuite(check-stablehlo-tosa-lit "Running the stablehlo-tosa regression tests"
+        ${CMAKE_CURRENT_BINARY_DIR}
+        DEPENDS
+        FileCheck
+        stablehlo-opt
+        )
+# add_dependencies(check-stablehlo-tests check-stablehlo-tosa-lit)
diff --git a/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/conversions/tosa/tests/binary.mlir
new file mode 100644
index 0000000..6daebc1
--- /dev/null
+++ b/stablehlo/conversions/tosa/tests/binary.mlir
@@ -0,0 +1,227 @@
+// RUN: stablehlo-opt %s --tosa-legalize-stablehlo | FileCheck %s
+
+// CHECK-LABEL: @add
+func.func @add(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.add
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @and
+func.func @and(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: tosa.bitwise_and
+  %0 = "stablehlo.and"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}
+
+// CHECK-LABEL: @compare_eq
+func.func @compare_eq(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xi1> {
+  // CHECK: tosa.equal
+  %0 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction EQ>} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  return %0 : tensor<10xi1>
+}
+
+// CHECK-LABEL: @compare_lt
+func.func @compare_lt(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xi1> {
+  // CHECK: stablehlo.compare
+  %0 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction LT>} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  return %0 : tensor<10xi1>
+}
+
+// CHECK-LABEL: @compare_ne
+func.func @compare_ne(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi1> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.equal"(%arg0, %arg1)
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.logical_not"(%[[VAR0]])
+  %0 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction NE>} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  return %0 : tensor<10xi1>
+}
+
+// CHECK-LABEL: @concatenate
+func.func @concatenate(%arg0 : tensor<3x3xf32>, %arg1 : tensor<3x3xf32>) -> tensor<6x3xf32> {
+  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
+  %0 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
+  return %0 : tensor<6x3xf32>
+}
+
+// CHECK-LABEL: @divide
+func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: tosa.div
+  %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}
+
+// CHECK-LABEL: @divide_f32
+func.func @divide_f32(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
+  // tosa.div only supports i32, so this should not legalize.
+  // CHECK: stablehlo.divide
+  %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @dot_vector_vector
+func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor<f32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 3>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 3, 1>}
+  // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]])
+  %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: @dot_vector_matrix
+func.func @dot_vector_matrix(%arg0 : tensor<2xf32>, %arg1 : tensor<2x3xf32>) -> tensor<3xf32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 2>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 2, 3>}
+  // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]])
+  %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2xf32>, tensor<2x3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// CHECK-LABEL: @dot_matrix_vector
+func.func @dot_matrix_vector(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3xf32>) -> tensor<2xf32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 2, 3>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 3, 1>}
+  // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]])
+  %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @dot_matrix_matrix
+func.func @dot_matrix_matrix(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>) -> tensor<2x4xf32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 2, 3>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 3, 4>}
+  // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]])
+  %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
+
+// CHECK-LABEL: @gather
+func.func @gather(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x2xi32>) -> tensor<3x2x5xi32> {
+  // CHECK: tosa.gather
+  %0 = "stablehlo.gather"(%arg0, %arg1) {
+    dimension_numbers = #stablehlo.gather<
+      collapsed_slice_dims = [0],
+      index_vector_dim = 1,
+      offset_dims = [1, 2],
+      start_index_map = [0, 1]
+    >,
+    indices_are_sorted = false,
+    slice_sizes = dense<[1, 2, 5]> : tensor<3xi64>
+  } : (tensor<3x4x5xi32>, tensor<3x2xi32>) -> tensor<3x2x5xi32>
+  return %0 : tensor<3x2x5xi32>
+}
+
+// CHECK-LABEL: @gather_unranked
+func.func @gather_unranked(%arg0 : tensor<*xi32>, %arg1 : tensor<3x2xi32>) -> tensor<*xi32> {
+  // This lowering does not support unranked tensors, so this should not
+  // legalize.
+  // CHECK: stablehlo.gather
+  %0 = "stablehlo.gather"(%arg0, %arg1) {
+    dimension_numbers = #stablehlo.gather<
+      collapsed_slice_dims = [0],
+      index_vector_dim = 1,
+      offset_dims = [1, 2],
+      start_index_map = [0, 1]
+    >,
+    indices_are_sorted = false,
+    slice_sizes = dense<[1, 2, 5]> : tensor<3xi64>
+  } : (tensor<*xi32>, tensor<3x2xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+// CHECK-LABEL: @maximum
+func.func @maximum(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.maximum
+  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @maximum_f64
+func.func @maximum_f64(%arg0 : tensor<10xf64>, %arg1 : tensor<10xf64>) -> tensor<10xf64> {
+  // CHECK: stablehlo.maximum
+  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<10xf64>, tensor<10xf64>) -> tensor<10xf64>
+  return %0 : tensor<10xf64>
+}
+
+// CHECK-LABEL: @minimum
+func.func @minimum(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.minimum
+  %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @multiply
+func.func @multiply(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.mul
+  %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @or
+func.func @or(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: tosa.bitwise_or
+  %0 = "stablehlo.or"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}
+
+// CHECK-LABEL: @power
+func.func @power(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.pow
+  %0 = "stablehlo.power"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @reduce_max
+func.func @reduce_max(%arg0: tensor<1x10xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
+  // CHECK: tosa.reduce_max
+  // CHECK: tosa.reshape
+  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// CHECK-LABEL: @reduce_sum
+func.func @reduce_sum(%arg0: tensor<5x4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> {
+  // CHECK: tosa.reduce_sum
+  // CHECK: tosa.reshape
+  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xf32>, tensor<f32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: @shift_left
+func.func @shift_left(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: tosa.logical_left_shift
+  %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}
+
+// CHECK-LABEL: @shift_right_logical
+func.func @shift_right_logical(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: tosa.logical_right_shift
+  %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}
+
+// CHECK-LABEL: @subtract
+func.func @subtract(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.sub
+  %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @xor
+func.func @xor(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: tosa.bitwise_xor
+  %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}
diff --git a/stablehlo/conversions/tosa/tests/lit.cfg.py b/stablehlo/conversions/tosa/tests/lit.cfg.py
new file mode 100644
index 0000000..c476e6b
--- /dev/null
+++ b/stablehlo/conversions/tosa/tests/lit.cfg.py
@@ -0,0 +1,38 @@
+# Copyright 2022 OpenXLA Authors. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -*- Python -*-
+# pylint: disable=undefined-variable
+
+import os
+
+import lit.formats
+from lit.llvm import llvm_config
+
+# Populate Lit configuration with the minimal required metadata.
+# Some metadata is populated in lit.site.cfg.py.in.
+config.name = 'STABLEHLO_TOSA_OPT'
+config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
+config.suffixes = ['.mlir']
+config.test_source_root = os.path.dirname(__file__)
+
+# Make LLVM and StableHLO tools available in RUN directives
+tools = [
+  'stablehlo-opt',
+  'FileCheck',
+]
+tool_dirs = [
+  config.llvm_tools_dir,
+  config.stablehlo_tools_dir,
+]
+llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/stablehlo/conversions/tosa/tests/lit.site.cfg.py.in b/stablehlo/conversions/tosa/tests/lit.site.cfg.py.in
new file mode 100644
index 0000000..bc65774
--- /dev/null
+++ b/stablehlo/conversions/tosa/tests/lit.site.cfg.py.in
@@ -0,0 +1,20 @@
+# Copyright 2022 OpenXLA Authors. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+@LIT_SITE_CFG_IN_HEADER@
+
+import lit.llvm
+lit.llvm.initialize(lit_config, config)
+config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
+config.stablehlo_tools_dir = "@STABLEHLO_TOOLS_DIR@"
+lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@/stablehlo/conversions/tosa/tests/lit.cfg.py")
diff --git a/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/conversions/tosa/tests/nullary.mlir
new file mode 100644
index 0000000..5896599
--- /dev/null
+++ b/stablehlo/conversions/tosa/tests/nullary.mlir
@@ -0,0 +1,32 @@
+// RUN: stablehlo-opt %s --tosa-legalize-stablehlo | FileCheck %s
+
+// CHECK-LABEL: @constant
+func.func @constant() -> tensor<10xf32> {
+  // CHECK: tosa.const
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @constant_f64
+func.func @constant_f64() -> tensor<10xf64> {
+  // TOSA does not support 64-bit types, so this should not legalize.
+  // CHECK: stablehlo.constant
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
+  return %0 : tensor<10xf64>
+}
+
+// CHECK-LABEL: @iota_dimension_0
+func.func @iota_dimension_0() -> tensor<4x8xf32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.tile"(%[[VAR0]]) {multiples = array<i64: 1, 8>}
+  %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>)
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @iota_dimension_1
+func.func @iota_dimension_1() -> tensor<4x8xi32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi32>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.tile"(%[[VAR0]]) {multiples = array<i64: 4, 1>}
+  %0 = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<4x8xi32>)
+  return %0 : tensor<4x8xi32>
+}
diff --git a/stablehlo/conversions/tosa/tests/ternary.mlir b/stablehlo/conversions/tosa/tests/ternary.mlir
new file mode 100644
index 0000000..52d6411
--- /dev/null
+++ b/stablehlo/conversions/tosa/tests/ternary.mlir
@@ -0,0 +1,15 @@
+// RUN: stablehlo-opt %s --tosa-legalize-stablehlo | FileCheck %s
+
+// CHECK-LABEL: @concatenate
+func.func @concatenate(%arg0 : tensor<5x2xf32>, %arg1 : tensor<5x5xf32>, %arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
+  // CHECK: "tosa.concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
+  %0 = "stablehlo.concatenate"(%arg0, %arg1, %arg2) {dimension = 1 : i64} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
+  return %0 : tensor<5x14xf32>
+}
+
+// CHECK-LABEL: @select
+func.func @select(%arg0 : tensor<10xi1>, %arg1 : tensor<10xf32>, %arg2 : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.select
+  %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<10xi1>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
diff --git a/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/conversions/tosa/tests/unary.mlir
new file mode 100644
index 0000000..1d70d9b
--- /dev/null
+++ b/stablehlo/conversions/tosa/tests/unary.mlir
@@ -0,0 +1,157 @@
+// RUN: stablehlo-opt %s --tosa-legalize-stablehlo | FileCheck %s
+
+// CHECK-LABEL: @abs
+func.func @abs(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.abs
+  %0 = "stablehlo.abs"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @ceil
+func.func @ceil(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.ceil
+  %0 = "stablehlo.ceil"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @convert
+func.func @convert(%arg : tensor<10xi32>) -> tensor<10xf32> {
+  // CHECK: tosa.cast
+  %0 = "stablehlo.convert"(%arg) : (tensor<10xi32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @exponential
+func.func @exponential(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.exp
+  %0 = "stablehlo.exponential"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @exponential_minus_one
+func.func @exponential_minus_one(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00>
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.exp"(%arg0)
+  // CHECK-DAG: %[[VAR2:.*]] = "tosa.sub"(%[[VAR1]], %[[VAR0]])
+  %0 = "stablehlo.exponential_minus_one"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @floor
+func.func @floor(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.floor
+  %0 = "stablehlo.floor"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @is_finite
+func.func @is_finite(%arg : tensor<10xf32>) -> tensor<10xi1> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0x7F800000>
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.abs"(%arg0)
+  // CHECK-DAG: %[[VAR2:.*]] = "tosa.equal"(%[[VAR1]], %[[VAR0]])
+  // CHECK-DAG: %[[VAR3:.*]] = "tosa.logical_not"(%[[VAR2]])
+  %0 = "stablehlo.is_finite"(%arg) : (tensor<10xf32>) -> tensor<10xi1>
+  return %0 : tensor<10xi1>
+}
+
+// CHECK-LABEL: @log
+func.func @log(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.log
+  %0 = "stablehlo.log"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @log_plus_one
+func.func @log_plus_one(%arg : tensor<10xf16>) -> tensor<10xf16> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00>
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]])
+  // CHECK-DAG: %[[VAR2:.*]] = "tosa.log"(%[[VAR1]])
+  %0 = "stablehlo.log_plus_one"(%arg) : (tensor<10xf16>) -> tensor<10xf16>
+  return %0 : tensor<10xf16>
+}
+
+// CHECK-LABEL: @negate
+func.func @negate(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.negate
+  %0 = "stablehlo.negate"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @slice
+func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> {
+  // CHECK: "tosa.slice"(%arg0) {size = array<i64: 2, 2>, start = array<i64: 2, 1>}
+  %0 = "stablehlo.slice"(%arg) {
+    start_indices = dense<[2, 1]> : tensor<2xi64>,
+    limit_indices = dense<[4, 3]> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } : (tensor<4x3xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: @slice_stride_not_one
+func.func @slice_stride_not_one(%arg : tensor<4x3xf32>) -> tensor<2x1xf32> {
+  // tosa.slice only supports strides of 1, so this should not legalize.
+  // CHECK: "stablehlo.slice"
+  %0 = "stablehlo.slice"(%arg) {
+    start_indices = dense<[2, 1]> : tensor<2xi64>,
+    limit_indices = dense<[4, 3]> : tensor<2xi64>,
+    strides = dense<[1, 2]> : tensor<2xi64>
+  } : (tensor<4x3xf32>) -> tensor<2x1xf32>
+  return %0 : tensor<2x1xf32>
+}
+
+// CHECK-LABEL: @slice_rank_seven
+func.func @slice_rank_seven(%arg : tensor<2x3x4x5x6x7x8xf32>) -> tensor<1x2x3x4x5x6x7xf32> {
+  // tosa.slice only supports 1D to 6D tensors, so this should not legalize.
+  // CHECK: "stablehlo.slice"
+  %0 = "stablehlo.slice"(%arg) {
+    start_indices = dense<[1, 1, 1, 1, 1, 1, 1]> : tensor<7xi64>,
+    limit_indices = dense<[2, 3, 4, 5, 6, 7, 8]> : tensor<7xi64>,
+    strides = dense<[1, 1, 1, 1, 1, 1, 1]> : tensor<7xi64>
+  } : (tensor<2x3x4x5x6x7x8xf32>) -> tensor<1x2x3x4x5x6x7xf32>
+  return %0 : tensor<1x2x3x4x5x6x7xf32>
+}
+
+// CHECK-LABEL: @tanh
+func.func @tanh(%arg : tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: tosa.tanh
+  %0 = "stablehlo.tanh"(%arg) : (tensor<10xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @transpose
+func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> tensor<3xi64>
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.transpose"(%arg0, %[[VAR0]])
+  %0 = "stablehlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
+  return %0 : tensor<3x2x1xf32>
+}
+
+// CHECK-LABEL: @while
+func.func @while(%arg0: tensor<i32>) -> tensor<i32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<3> : tensor<i32>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<1> : tensor<i32>}
+  // CHECK:     %[[VAR2:.*]] = "tosa.while_loop"(%arg0) ({
+  // CHECK:     ^bb0(%[[ARG0:.+]]: tensor<i32>):
+  // CHECK:       %[[VAR3:.*]] = "tosa.equal"(%[[ARG0]], %[[VAR0]])
+  // CHECK:       "tosa.yield"(%[[VAR3]])
+  // CHECK:     }, {
+  // CHECK:     ^bb0(%[[ARG0:.+]]: tensor<i32>):
+  // CHECK:       %[[VAR4:.*]] = "tosa.add"(%[[ARG0]], %[[VAR1]])
+  // CHECK:       "tosa.yield"(%[[VAR4]])
+  // CHECK:     }) : (tensor<i32>) -> tensor<i32>
+  // CHECK:     return %[[VAR2]] : tensor<i32>
+  // CHECK:   }
+  %0 = "stablehlo.while"(%arg0) ( {
+  ^bb0(%arg1: tensor<i32>):
+    %1 = "stablehlo.constant"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+    %2 = "stablehlo.compare"(%arg1, %1) {comparison_direction = #stablehlo<comparison_direction EQ>}: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "stablehlo.return"(%2) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg1: tensor<i32>):
+    %1 = "stablehlo.constant"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %2 = "stablehlo.add"(%arg1, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "stablehlo.return"(%2) : (tensor<i32>) -> ()
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return %0 : tensor<i32>
+}
diff --git a/stablehlo/conversions/tosa/transforms/BUILD.bazel b/stablehlo/conversions/tosa/transforms/BUILD.bazel
new file mode 100644
index 0000000..40fc0e4
--- /dev/null
+++ b/stablehlo/conversions/tosa/transforms/BUILD.bazel
@@ -0,0 +1,72 @@
+# Legalizations and transforms for StableHLO -> TOSA.
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+package_group(
+    name = "internal",
+    packages = [],
+)
+
+gentbl_cc_library(
+    name = "StablehloTOSAPDLLPatternsIncGen",
+    tbl_outs = [
+        (
+            ["-x=cpp"],
+            "legalize_stablehlo.pdll.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-pdll",
+    td_file = "legalize_stablehlo.pdll",
+    deps = [
+        "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:TosaDialectTdFiles",
+        "//:stablehlo_ops_td_files",
+    ],
+)
+
+gentbl_cc_library(
+    name = "StablehloTOSATransformsPassIncGen",
+    strip_include_prefix = ".",
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=StablehloTOSATransforms",
+            ],
+            "passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "passes.td",
+    deps = [
+        "@llvm-project//mlir:PassBaseTdFiles",
+    ],
+)
+
+cc_library(
+    name = "StablehloTOSATransforms",
+    srcs = [
+        "legalize_stablehlo.cc",
+        "prepare_stablehlo.cc",
+    ],
+    hdrs = [
+        "passes.h",
+    ],
+    includes = ["."],
+    deps = [
+        "//:stablehlo_ops",
+        ":StablehloTOSAPDLLPatternsIncGen",
+        ":StablehloTOSATransformsPassIncGen",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:TosaDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
diff --git a/stablehlo/conversions/tosa/transforms/CMakeLists.txt b/stablehlo/conversions/tosa/transforms/CMakeLists.txt
new file mode 100644
index 0000000..347affa
--- /dev/null
+++ b/stablehlo/conversions/tosa/transforms/CMakeLists.txt
@@ -0,0 +1,35 @@
+# Copyright 2022 OpenXLA Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+set(LLVM_TARGET_DEFINITIONS passes.td)
+mlir_tablegen(passes.h.inc -gen-pass-decls -name StablehloTOSATransforms)
+add_public_tablegen_target(StablehloTOSATransformsPassIncGen)
+
+add_mlir_library(StablehloTOSATransforms
+  legalize_stablehlo.cc
+  prepare_stablehlo.cc
+
+  DEPENDS
+  StablehloTOSATransformsPassIncGen
+  StablehloTOSAPDLLPatternsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/stablehlo/conversions/tosa/transforms/legalize_stablehlo.cc b/stablehlo/conversions/tosa/transforms/legalize_stablehlo.cc
new file mode 100644
index 0000000..ff431cd
--- /dev/null
+++ b/stablehlo/conversions/tosa/transforms/legalize_stablehlo.cc
@@ -0,0 +1,502 @@
+/* Copyright 2022 OpenXLA Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/conversions/tosa/transforms/passes.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+#define GEN_PASS_DEF_TOSALEGALIZESTABLEHLOPASS
+#include "stablehlo/conversions/tosa/transforms/passes.h.inc"
+
+#define PASS_NAME "tosa-legalize-stablehlo"
+#define DEBUG_TYPE PASS_NAME
+
+#include "stablehlo/conversions/tosa/transforms/legalize_stablehlo.pdll.h.inc"
+
+namespace mlir {
+namespace tosa {
+namespace {
+
+struct LegalizeStablehlo
+    : ::impl::TosaLegalizeStablehloPassBase<LegalizeStablehlo> {
+  void runOnOperation() final;
+
+  LogicalResult initialize(MLIRContext* ctx) override;
+
+ private:
+  FrozenRewritePatternSet patterns;
+};
+
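+// Lowers EQ to tosa.equal and NE to tosa.equal followed by
+// tosa.logical_not; other comparison directions are not yet implemented
+// and are left unconverted.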
+struct ConvertStablehloCompareOp
+    : public OpRewritePattern<stablehlo::CompareOp> {
+  using OpRewritePattern<stablehlo::CompareOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::CompareOp op,
+                                PatternRewriter& rewriter) const override {
+    auto direction = op.getComparisonDirection();
+    auto resultType = op->getResultTypes().front();
+
+    switch (direction) {
+      case stablehlo::ComparisonDirection::EQ: {
+        rewriter.replaceOpWithNewOp<tosa::EqualOp>(op, resultType, op.getLhs(),
+                                                   op.getRhs());
+        break;
+      }
+      case stablehlo::ComparisonDirection::NE: {
+        auto equalOp = rewriter.create<tosa::EqualOp>(op->getLoc(), resultType,
+                                                      op.getLhs(), op.getRhs());
+        rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, resultType,
+                                                        equalOp);
+        break;
+      }
+      default: {
+        return rewriter.notifyMatchFailure(
+            op, "comparison direction not yet implemented");
+      }
+    }
+    return success();
+  }
+};
+
+// TODO(jennik): Move this lowering to PDLL when variadic tensors are supported.
+struct ConvertStablehloConcatenateOp
+    : public OpRewritePattern<stablehlo::ConcatenateOp> {
+  using OpRewritePattern<stablehlo::ConcatenateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::ConcatenateOp op,
+                                PatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<tosa::ConcatOp>(
+        op, op.getResult().getType(), op.getInputs(), op.getDimension());
+    return success();
+  }
+};
+
+struct ConvertStablehloDotOp : public OpRewritePattern<stablehlo::DotOp> {
+  using OpRewritePattern<stablehlo::DotOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::DotOp op,
+                                PatternRewriter& rewriter) const override {
+    auto lhsType = op.getLhs().getType().dyn_cast<RankedTensorType>();
+    auto rhsType = op.getRhs().getType().dyn_cast<RankedTensorType>();
+    if (!lhsType || !rhsType) {
+      return rewriter.notifyMatchFailure(op, "input tensors are not ranked");
+    }
+
+    auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
+    if (!resultType) {
+      return rewriter.notifyMatchFailure(op,
+                                         "result tensor does not have shape");
+    }
+
+    if (lhsType.getElementType() != rhsType.getElementType()) {
+      return rewriter.notifyMatchFailure(
+          op, "lhs and rhs element types must match");
+    }
+
+    auto lhsShape = lhsType.getShape();
+    auto rhsShape = rhsType.getShape();
+    auto resultShape = resultType.getShape();
+    llvm::SmallVector<int64_t, 3> lhsReshape;
+    llvm::SmallVector<int64_t, 3> rhsReshape;
+    llvm::SmallVector<int64_t, 3> matMulShape;
+
+    // tosa.matmul requires input tensors to have a rank of 3, so lhs and rhs
+    // need to be reshaped first.
+    if (lhsType.getRank() == 1) {
+      // Reshape lhs to [1, 1, N].
+      lhsReshape = {1, 1, lhsShape[0]};
+      if (rhsType.getRank() == 1) {
+        // Reshape rhs to [1, N, 1].
+        rhsReshape = {1, rhsShape[0], 1};
+        // MatMul shape is [1, 1, 1].
+        matMulShape = {1, 1, 1};
+      } else if (rhsType.getRank() == 2) {
+        // Reshape rhs to [1, N, K].
+        rhsReshape = {1, rhsShape[0], rhsShape[1]};
+        // MatMul shape is [1, 1, K].
+        matMulShape = {1, 1, rhsShape[1]};
+      } else {
+        return rewriter.notifyMatchFailure(op, "rhs must have rank of 1 or 2");
+      }
+    } else if (lhsType.getRank() == 2) {
+      // Reshape lhs to [1, M, K].
+      lhsReshape = {1, lhsShape[0], lhsShape[1]};
+      if (rhsType.getRank() == 1) {
+        // Reshape rhs to [1, K, 1].
+        rhsReshape = {1, rhsShape[0], 1};
+        // MatMul shape is [1, M, 1].
+        matMulShape = {1, lhsShape[0], 1};
+      } else if (rhsType.getRank() == 2) {
+        // Reshape rhs to [1, K, N].
+        rhsReshape = {1, rhsShape[0], rhsShape[1]};
+        // MatMul shape is [1, M, N].
+        matMulShape = {1, lhsShape[0], rhsShape[1]};
+      } else {
+        return rewriter.notifyMatchFailure(op, "rhs must have rank of 1 or 2");
+      }
+    } else {
+      return rewriter.notifyMatchFailure(op, "lhs must have rank of 1 or 2");
+    }
+
+    auto lhsReshapeType =
+        RankedTensorType::get(lhsReshape, lhsType.getElementType());
+    auto lhsReshapeOp = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(), lhsReshapeType, op.getLhs(),
+        rewriter.getDenseI64ArrayAttr(lhsReshape));
+
+    auto rhsReshapeType =
+        RankedTensorType::get(rhsReshape, rhsType.getElementType());
+    auto rhsReshapeOp = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(), rhsReshapeType, op.getRhs(),
+        rewriter.getDenseI64ArrayAttr(rhsReshape));
+
+    auto matMulType =
+        RankedTensorType::get(matMulShape, lhsType.getElementType());
+    auto matMulOp = rewriter.create<tosa::MatMulOp>(op->getLoc(), matMulType,
+                                                    lhsReshapeOp, rhsReshapeOp);
+
+    // Reshape the matmul result back to the original result shape.
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultType, matMulOp, rewriter.getDenseI64ArrayAttr(resultShape));
+    return success();
+  }
+};
+
+// TODO(jennik): Consider the case of a non-constant expansion.
+struct ConvertStablehloIotaOp : public OpRewritePattern<stablehlo::IotaOp> {
+  using OpRewritePattern<stablehlo::IotaOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::IotaOp op,
+                                PatternRewriter& rewriter) const override {
+    auto resultType = op.getResult().getType();
+    auto elementType = resultType.cast<ShapedType>().getElementType();
+    auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
+
+    if (!resultRankedType) {
+      return rewriter.notifyMatchFailure(op, "result tensor must be ranked");
+    }
+    if (!resultRankedType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, "result tensor must be static");
+    }
+
+    auto resultShape = resultRankedType.getShape();
+    auto iotaDimension = op.getIotaDimension();
+    int64_t iotaArrayLength = resultShape[iotaDimension];
+
+    // Create a const op of [0, 1, 2...iotaArrayLength - 1] to be tiled.
+    llvm::SmallVector<Attribute, 4> constValues;
+    constValues.resize(iotaArrayLength);
+    for (int i = 0; i < iotaArrayLength; i++) {
+      if (elementType.isa<FloatType>()) {
+        constValues[i] = rewriter.getFloatAttr(elementType, i);
+      } else {
+        constValues[i] = rewriter.getIntegerAttr(elementType, i);
+      }
+    }
+
+    RankedTensorType constType =
+        RankedTensorType::get(iotaArrayLength, elementType);
+    auto constOp = rewriter.create<tosa::ConstOp>(
+        op.getLoc(), constType, DenseElementsAttr::get(constType, constValues));
+
+    // Create the multiples attr for the tile op, where all dimensions except
+    // the iota dimension are multiplied.
+    llvm::SmallVector<int64_t, 4> tileMultiples;
+    size_t tileMultiplesSize = resultShape.size();
+    tileMultiples.resize(tileMultiplesSize);
+
+    for (size_t i = 0; i < tileMultiplesSize; i++) {
+      if (i == iotaDimension) {
+        tileMultiples[i] = 1;
+      } else {
+        tileMultiples[i] = resultShape[i];
+      }
+    }
+
+    // Tile the const array to the result shape of the iota op.
+    rewriter.replaceOpWithNewOp<tosa::TileOp>(
+        op, resultType, constOp, rewriter.getDenseI64ArrayAttr(tileMultiples));
+    return success();
+  }
+};
+
+// This legalization supports the case where the StableHLO start_indices
+// directly map to the TOSA indices.
+struct ConvertStablehloGatherOp : public OpRewritePattern<stablehlo::GatherOp> {
+  using OpRewritePattern<stablehlo::GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::GatherOp op,
+                                PatternRewriter& rewriter) const override {
+    // The input operand must be 3D, with shape [N, K, C].
+    auto operand = op.getOperand();
+    auto operandType = operand.getType().dyn_cast<RankedTensorType>();
+    if (!operandType) {
+      return rewriter.notifyMatchFailure(op, "requires ranked operand shape");
+    }
+    if (operandType.getRank() != 3) {
+      return rewriter.notifyMatchFailure(op, "operand must have rank of 3");
+    }
+
+    // The indices tensor must be 2D, with shape [N, W].
+    auto startIndices = op.getStartIndices();
+    auto startIndicesType = startIndices.getType().dyn_cast<RankedTensorType>();
+    if (!startIndicesType) {
+      return rewriter.notifyMatchFailure(op,
+                                         "requires ranked start_indices shape");
+    }
+    if (startIndicesType.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op,
+                                         "start_indices must have rank of 2");
+    }
+
+    // The result tensor must be 3D, with shape [N, W, C].
+    auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!resultType) {
+      return rewriter.notifyMatchFailure(op, "requires ranked output shape");
+    }
+    if (resultType.getRank() != 3) {
+      return rewriter.notifyMatchFailure(op, "result must have rank of 3");
+    }
+
+    auto operandShape = operand.getType().getShape();
+    auto startIndicesShape = startIndices.getType().getShape();
+    auto resultShape = resultType.getShape();
+
+    if (startIndicesShape[0] != resultShape[0] ||
+        startIndicesShape[1] != resultShape[1]) {
+      return rewriter.notifyMatchFailure(op,
+                                         "start_indices and result must have "
+                                         "same number of batches and indices");
+    }
+
+    if (operandShape[0] != resultShape[0] ||
+        operandShape[2] != resultShape[2]) {
+      return rewriter.notifyMatchFailure(op,
+                                         "operand and result must have same "
+                                         "number of batches and data channels");
+    }
+
+    auto startIndexMap = op.getDimensionNumbers().getStartIndexMap();
+    for (const auto& startIndex : llvm::enumerate(startIndexMap)) {
+      if (startIndex.value() != static_cast<int64_t>(startIndex.index())) {
+        return rewriter.notifyMatchFailure(op,
+                                           "start_index_map must be in order");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::GatherOp>(op, resultType, operand,
+                                                startIndices);
+    return success();
+  }
+};
+
+struct ConvertStablehloReduceOp : public OpRewritePattern<stablehlo::ReduceOp> {
+  using OpRewritePattern<stablehlo::ReduceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::ReduceOp op,
+                                PatternRewriter& rewriter) const override {
+    Block& bodyBlock = op.getBody().front();
+
+    // To lower to a tosa.reduce_* op, the body should contain the reduce op
+    // and a return op.
+    if (bodyBlock.getOperations().size() != 2) {
+      return rewriter.notifyMatchFailure(op, "body required to contain 2 ops");
+    }
+
+    auto operand = op.getInputs().front();
+    ShapedType inputType = operand.getType().cast<ShapedType>();
+    Operation& innerOp = bodyBlock.front();
+    uint64_t dimension = op.getDimensions().getValues<uint64_t>().begin()[0];
+    SmallVector<int64_t> innerShape(inputType.getShape());
+    innerShape[dimension] = 1;
+    Type innerTy = inputType.clone(innerShape);
+
+    Value reduceOpResult;
+    if (isa<stablehlo::AddOp>(innerOp)) {
+      reduceOpResult =
+          rewriter
+              .create<tosa::ReduceSumOp>(op->getLoc(), innerTy, operand,
+                                         rewriter.getI64IntegerAttr(dimension))
+              .getResult();
+    } else if (isa<stablehlo::MaxOp>(innerOp)) {
+      reduceOpResult =
+          rewriter
+              .create<tosa::ReduceMaxOp>(op->getLoc(), innerTy, operand,
+                                         rewriter.getI64IntegerAttr(dimension))
+              .getResult();
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "reducing along a " + innerOp.getName().getStringRef().str() +
+                  " op not supported");
+    }
+
+    // TOSA reduce ops do not remove the dimension being reduced, so reshape
+    // the reduced output and remove the reduction dimension.
+    llvm::SmallVector<int64_t, 2> outputShape;
+    int outputShapeLength = innerShape.size() - 1;
+    outputShape.resize(outputShapeLength);
+    for (int64_t i = 0; i < outputShapeLength; i++) {
+      if (i < static_cast<int64_t>(dimension)) {
+        outputShape[i] = innerShape[i];
+      } else {
+        outputShape[i] = innerShape[i + 1];
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getResultTypes().front(), reduceOpResult,
+        rewriter.getDenseI64ArrayAttr(outputShape));
+
+    return success();
+  }
+};
+
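+// stablehlo.return lowers to tosa.yield, e.g. for the regions of a
+// tosa.while_loop.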
+struct ConvertStablehloReturnOp : public OpRewritePattern<stablehlo::ReturnOp> {
+  using OpRewritePattern<stablehlo::ReturnOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::ReturnOp op,
+                                PatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, op->getResultTypes(),
+                                               op.getResults());
+    return success();
+  }
+};
+
+struct ConvertStablehloSliceOp : public OpRewritePattern<stablehlo::SliceOp> {
+  using OpRewritePattern<stablehlo::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::SliceOp op,
+                                PatternRewriter& rewriter) const override {
+    auto rank = op.getOperand().getType().getRank();
+    if (rank < 1 || rank > 6) {
+      return rewriter.notifyMatchFailure(
+          op, "tosa.slice only supports 1D to 6D tensors");
+    }
+
+    auto strides = op.getStrides().getValues<int64_t>();
+    for (auto stride : strides) {
+      if (stride != 1) {
+        return rewriter.notifyMatchFailure(
+            op, "tosa.slice only supports strides of 1");
+      }
+    }
+
+    auto startIndices = op.getStartIndices().getValues<int64_t>();
+    auto endIndices = op.getLimitIndices().getValues<int64_t>();
+
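+    // tosa.slice expects start/size attributes, so convert the StableHLO
+    // start/limit indices into sizes.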
+    llvm::SmallVector<int64_t, 2> size;
+    size.resize(startIndices.size());
+    llvm::SmallVector<int64_t, 2> startIndicesI64;
+    startIndicesI64.resize(startIndices.size());
+
+    for (int64_t i = 0; i < static_cast<int64_t>(startIndices.size()); i++) {
+      size[i] = endIndices[i] - startIndices[i];
+      startIndicesI64[i] = startIndices[i];
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(
+        op, op.getResult().getType(), op.getOperand(),
+        rewriter.getDenseI64ArrayAttr(startIndicesI64),
+        rewriter.getDenseI64ArrayAttr(size));
+    return success();
+  }
+};
+
+struct ConvertStablehloTransposeOp
+    : public OpRewritePattern<stablehlo::TransposeOp> {
+  using OpRewritePattern<stablehlo::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::TransposeOp op,
+                                PatternRewriter& rewriter) const override {
+    auto rank = op.getOperand().getType().getRank();
+    if (rank < 1 || rank > 6) {
+      return rewriter.notifyMatchFailure(
+          op, "tosa.transpose only supports 1D to 6D tensors");
+    }
+
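+    // tosa.transpose takes its permutation as a constant tensor operand
+    // rather than as an attribute.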
+    auto perms = op.getPermutation();
+    auto constOp = rewriter.create<tosa::ConstOp>(
+        op->getLoc(),
+        RankedTensorType::get({perms.size()}, rewriter.getI64Type()), perms);
+    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(op, op.getResult().getType(),
+                                                   op.getOperand(), constOp);
+    return success();
+  }
+};
+
+struct ConvertStablehloWhileOp : public OpRewritePattern<stablehlo::WhileOp> {
+  using OpRewritePattern<stablehlo::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::WhileOp op,
+                                PatternRewriter& rewriter) const override {
+    auto* cond = &op.getCond();
+    auto* body = &op.getBody();
+    auto newWhileOp = rewriter.create<tosa::WhileOp>(
+        op->getLoc(), op->getResultTypes(), op->getOperands());
+
+    auto* newCond = &newWhileOp->getRegion(0);
+    auto* newBody = &newWhileOp->getRegion(1);
+    rewriter.createBlock(newCond);
+    rewriter.createBlock(newBody);
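+    // Clone the original cond/body regions in front of the placeholder
+    // blocks created above, then erase the placeholders so only the cloned
+    // blocks remain.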
+
+    rewriter.cloneRegionBefore(*cond, &newCond->back());
+    rewriter.eraseBlock(&newCond->back());
+    rewriter.cloneRegionBefore(*body, &newBody->back());
+    rewriter.eraseBlock(&newBody->back());
+
+    rewriter.replaceOp(op, newWhileOp.getResults());
+    return success();
+  }
+};
+
+LogicalResult LegalizeStablehlo::initialize(MLIRContext* ctx) {
+  RewritePatternSet patternList(ctx);
+  populateGeneratedPDLLPatterns(patternList);
+  patternList.addWithLabel<ConvertStablehloCompareOp>({"StablehloCompare"},
+                                                      ctx);
+  patternList.addWithLabel<ConvertStablehloConcatenateOp>(
+      {"StablehloConcatenate"}, ctx);
+  patternList.addWithLabel<ConvertStablehloDotOp>({"StablehloDot"}, ctx);
+  patternList.addWithLabel<ConvertStablehloGatherOp>({"StablehloGather"}, ctx);
+  patternList.addWithLabel<ConvertStablehloIotaOp>({"StablehloIota"}, ctx);
+  patternList.addWithLabel<ConvertStablehloReduceOp>({"StablehloReduce"}, ctx);
+  patternList.addWithLabel<ConvertStablehloReturnOp>({"StablehloReturn"}, ctx);
+  patternList.addWithLabel<ConvertStablehloSliceOp>({"StablehloSlice"}, ctx);
+  patternList.addWithLabel<ConvertStablehloTransposeOp>({"StablehloTranspose"},
+                                                        ctx);
+  patternList.addWithLabel<ConvertStablehloWhileOp>({"StablehloWhile"}, ctx);
+  patterns = std::move(patternList);
+  return success();
+}
+
+void LegalizeStablehlo::runOnOperation() {
+  (void)applyPatternsAndFoldGreedily(getOperation(), patterns);
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeStablehloPass() {
+  return std::make_unique<LegalizeStablehlo>();
+}
+
+}  // namespace tosa
+}  // namespace mlir
diff --git a/stablehlo/conversions/tosa/transforms/legalize_stablehlo.pdll b/stablehlo/conversions/tosa/transforms/legalize_stablehlo.pdll
new file mode 100644
index 0000000..c9a2340
--- /dev/null
+++ b/stablehlo/conversions/tosa/transforms/legalize_stablehlo.pdll
@@ -0,0 +1,165 @@
+// Copyright 2022 OpenXLA Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.td"
+#include "stablehlo/dialect/StablehloOps.td"
+
+// Helper functions.
+Rewrite onesLike(op: Op, type: Type) -> Op [{
+  auto elementType = type.cast<mlir::TensorType>().getElementType();
+  llvm::SmallVector<mlir::Attribute, 4> outputValue;
+
+  if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) {
+    outputValue.push_back(rewriter.getFloatAttr(elementType, 1));
+  } else {
+    outputValue.push_back(rewriter.getIntegerAttr(elementType, 1));
+  }
+
+  return rewriter.create<mlir::tosa::ConstOp>(
+      op->getLoc(), type,
+      mlir::DenseElementsAttr::get(
+        llvm::cast<mlir::ShapedType>(type), outputValue));
+}];
+
+Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{
+  auto elementType = type.cast<mlir::TensorType>().getElementType();
+  const llvm::fltSemantics& semantic =
+      elementType.cast<mlir::FloatType>().getFloatSemantics();
+
+  llvm::SmallVector<mlir::Attribute, 4> outputValue;
+  outputValue.push_back(rewriter.getFloatAttr(
+    elementType, llvm::APFloat::getInf(semantic, false)));
+
+  return rewriter.create<mlir::tosa::ConstOp>(
+      op->getLoc(), type,
+      mlir::DenseElementsAttr::get(
+        llvm::cast<mlir::ShapedType>(type), outputValue));
+}];
+
+// Nullary ops.
+Pattern =>
+  replace op<stablehlo.constant> {value = input: Attr<_: Tosa_Tensor>}
+     with op<tosa.const> {value = input};
+
+// Unary ops.
+Pattern =>
+  replace op<stablehlo.abs>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.abs>(input);
+Pattern =>
+  replace op<stablehlo.ceil>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.ceil>(input);
+Pattern =>
+  replace op<stablehlo.convert>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.cast>(input);
+Pattern =>
+  replace op<stablehlo.exponential>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.exp>(input);
+Pattern {
+  let root = op<stablehlo.exponential_minus_one>
+                (input : Value<inputType: Tosa_Tensor>);
+  rewrite root with {
+    let ones = onesLike(root, inputType);
+    let expResult = op<tosa.exp>(input) -> (inputType);
+    let expMinusOneResult = op<tosa.sub>(expResult, ones) -> (inputType);
+    replace root with expMinusOneResult;
+  };
+}
+Pattern =>
+  replace op<stablehlo.floor>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.floor>(input);
+Pattern {
+  let root = op<stablehlo.is_finite>(input : Value<inputType: Tosa_Tensor>);
+  rewrite root with {
+    let positiveInfinity = positiveFloatInfinityLike(root, inputType);
+    let inputAbs = op<tosa.abs>(input) -> (inputType);
+    let equalsResult = op<tosa.equal>(positiveInfinity, inputAbs);
+    let notEqualsResult = op<tosa.logical_not>(equalsResult);
+    replace root with notEqualsResult;
+  };
+}
+Pattern =>
+  replace op<stablehlo.log>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.log>(input);
+Pattern {
+  let root = op<stablehlo.log_plus_one>(input : Value<inputType: Tosa_Tensor>);
+  rewrite root with {
+    let ones = onesLike(root, inputType);
+    let addResult = op<tosa.add>(input, ones) -> (inputType);
+    let logPlusOneResult = op<tosa.log>(addResult) -> (inputType);
+    replace root with logPlusOneResult;
+  };
+}
+Pattern =>
+  replace op<stablehlo.negate>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.negate>(input);
+Pattern =>
+  replace op<stablehlo.tanh>(input : Value<_: Tosa_Tensor>)
+     with op<tosa.tanh>(input);
+
+// Binary ops.
+Pattern =>
+  replace op<stablehlo.add>(input0 : Value<_: Tosa_Tensor>,
+                       input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.add>(input0, input1);
+Pattern =>
+  replace op<stablehlo.and>(input0 : Value<_: Tosa_Tensor>,
+                       input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.bitwise_and>(input0, input1);
+Pattern =>
+  replace op<stablehlo.divide>(input0 : Value<_: Tosa_Int32Tensor>,
+                          input1 : Value<_: Tosa_Int32Tensor>)
+     with op<tosa.div>(input0, input1);
+Pattern =>
+  replace op<stablehlo.maximum>(input0 : Value<_: Tosa_Tensor>,
+                           input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.maximum>(input0, input1);
+Pattern =>
+  replace op<stablehlo.minimum>(input0 : Value<_: Tosa_Tensor>,
+                           input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.minimum>(input0, input1);
+Pattern =>
+  replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
+                            input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.mul>(input0, input1) {shift = attr<"0 : i32">};
+Pattern =>
+  replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
+                           input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.bitwise_or>(input0, input1);
+Pattern =>
+  replace op<stablehlo.power>(input0 : Value<_: Tosa_Tensor>,
+                              input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.pow>(input0, input1);
+Pattern =>
+  replace op<stablehlo.shift_left>(input0 : Value<_: Tosa_Tensor>,
+                                   input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.logical_left_shift>(input0, input1);
+Pattern =>
+  replace op<stablehlo.shift_right_logical>(input0 : Value<_: Tosa_Tensor>,
+                                            input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.logical_right_shift>(input0, input1);
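+// Only the logical shifts are mapped; stablehlo.shift_right_arithmetic
+// would correspond to tosa.arithmetic_right_shift but does not appear in
+// this pattern set.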
+Pattern =>
+  replace op<stablehlo.subtract>(input0 : Value<_: Tosa_Tensor>,
+                                 input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.sub>(input0, input1);
+Pattern =>
+  replace op<stablehlo.xor>(input0 : Value<_: Tosa_Tensor>,
+                            input1 : Value<_: Tosa_Tensor>)
+     with op<tosa.bitwise_xor>(input0, input1);
+
+// Ternary ops.
+Pattern =>
+  replace op<stablehlo.select>(input0 : Value<_: Tosa_Tensor>,
+                               input1 : Value<_: Tosa_Tensor>,
+                               input2 : Value<_: Tosa_Tensor>)
+     with op<tosa.select>(input0, input1, input2);
diff --git a/stablehlo/conversions/tosa/transforms/passes.h b/stablehlo/conversions/tosa/transforms/passes.h
new file mode 100644
index 0000000..52010a9
--- /dev/null
+++ b/stablehlo/conversions/tosa/transforms/passes.h
@@ -0,0 +1,37 @@
+/* Copyright 2022 OpenXLA Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef STABLEHLO_CONVERSIONS_TOSA_TRANSFORMS_PASSES_H
+#define STABLEHLO_CONVERSIONS_TOSA_TRANSFORMS_PASSES_H
+
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeStablehloPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createPrepareStablehloPass();
+
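+// GEN_PASS_REGISTRATION makes the tablegen'd passes.h.inc emit
+// registerTosaLegalizeStablehloPassPass() and
+// registerTosaPrepareStablehloPassPass(), which StablehloOptMain.cpp calls.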
+#define GEN_PASS_REGISTRATION
+#define GEN_PASS_DECL_TOSALEGALIZESTABLEHLOPASS
+#include "stablehlo/conversions/tosa/transforms/passes.h.inc"
+
+}  // namespace tosa
+}  // namespace mlir
+
+#endif  // STABLEHLO_CONVERSIONS_TOSA_TRANSFORMS_PASSES_H
diff --git a/stablehlo/conversions/tosa/transforms/passes.td b/stablehlo/conversions/tosa/transforms/passes.td
new file mode 100644
index 0000000..b38892b
--- /dev/null
+++ b/stablehlo/conversions/tosa/transforms/passes.td
@@ -0,0 +1,32 @@
+/* Copyright 2022 OpenXLA Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/Pass/PassBase.td"
+
+def TosaLegalizeStablehloPass : Pass<"tosa-legalize-stablehlo", "mlir::func::FuncOp"> {
+  let summary = "Legalize from Stablehlo to TOSA";
+  let constructor = "createLegalizeStablehloPass()";
+  let dependentDialects = ["::mlir::tosa::TosaDialect"];
+}
+
+def TosaPrepareStablehloPass : Pass<"tosa-prepare-stablehlo", "mlir::func::FuncOp"> {
+  let summary = "Prepare Stablehlo for lowering to TOSA";
+  let description = [{
+    This pass adds rewriters to make Stablehlo ops more compatible with TOSA ops.
+    Currently simplifies stablehlo.dot_general into stablehlo.dot for easier lowering.
+  }];
+  let constructor = "createPrepareStablehloPass()";
+  let dependentDialects = ["::mlir::tosa::TosaDialect"];
+}
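+
+// Typical invocation (a sketch; the pass flags come from the definitions
+// above):
+//   stablehlo-opt --tosa-prepare-stablehlo --tosa-legalize-stablehlo in.mlir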
diff --git a/stablehlo/conversions/tosa/transforms/prepare_stablehlo.cc b/stablehlo/conversions/tosa/transforms/prepare_stablehlo.cc
new file mode 100644
index 0000000..8862411
--- /dev/null
+++ b/stablehlo/conversions/tosa/transforms/prepare_stablehlo.cc
@@ -0,0 +1,59 @@
+/* Copyright 2022 OpenXLA Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/conversions/tosa/transforms/passes.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+#define GEN_PASS_DEF_TOSAPREPARESTABLEHLOPASS
+#include "stablehlo/conversions/tosa/transforms/passes.h.inc"
+
+#define PASS_NAME "tosa-prepare-stablehlo"
+#define DEBUG_TYPE PASS_NAME
+
+namespace mlir {
+namespace tosa {
+namespace {
+
+class PrepareStablehlo
+    : public ::impl::TosaPrepareStablehloPassBase<PrepareStablehlo> {
+ public:
+  explicit PrepareStablehlo() = default;
+  void runOnOperation() override;
+};
+
+void PrepareStablehlo::runOnOperation() {
+  auto* ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  // The equivalent canonicalization patterns are not yet available in this
+  // repository.
+  // TODO: Enable once the upstreaming decision is made.
+  // stablehlo::DotGeneralOp::getCanonicalizationPatterns(patterns, ctx);
+  // stablehlo::populateGeneralDotOpLoweringPatterns(&patterns, ctx);
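+  // The greedy driver's return value reports convergence, not correctness;
+  // it is deliberately discarded since the pattern set above is currently
+  // empty.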
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> createPrepareStablehloPass() {
+  return std::make_unique<PrepareStablehlo>();
+}
+
+}  // namespace tosa
+}  // namespace mlir
diff --git a/stablehlo/tools/CMakeLists.txt b/stablehlo/tools/CMakeLists.txt
index 61850d8..d7bf7e2 100644
--- a/stablehlo/tools/CMakeLists.txt
+++ b/stablehlo/tools/CMakeLists.txt
@@ -29,6 +29,7 @@
         StablehloRegister
         StablehloTestUtils
         StablehloPasses
+        StablehloTOSATransforms
         )
 add_llvm_executable(stablehlo-opt StablehloOptMain.cpp)
 llvm_update_compile_flags(stablehlo-opt)
diff --git a/stablehlo/tools/StablehloOptMain.cpp b/stablehlo/tools/StablehloOptMain.cpp
index e510926..b4f303c 100644
--- a/stablehlo/tools/StablehloOptMain.cpp
+++ b/stablehlo/tools/StablehloOptMain.cpp
@@ -13,9 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "stablehlo/conversions/tosa/transforms/passes.h"
 #include "stablehlo/dialect/Register.h"
 #include "stablehlo/tests/TestUtils.h"
 #include "stablehlo/transforms/Passes.h"
@@ -24,6 +27,8 @@
   mlir::registerAllPasses();
   mlir::hlo::registerAllTestPasses();
   mlir::stablehlo::registerPasses();
+  mlir::tosa::registerTosaLegalizeStablehloPassPass();
+  mlir::tosa::registerTosaPrepareStablehloPassPass();
 
   mlir::DialectRegistry registry;
   mlir::registerAllDialects(registry);