Add option to lower a TensorFlow model to the StableHLO Dialect (#358)

diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml
index 6387dc9..12386e3 100644
--- a/.github/workflows/build-and-test.yaml
+++ b/.github/workflows/build-and-test.yaml
@@ -97,9 +97,10 @@
         python3 ../${EMITC}/scripts/model_to_savedmodel_with_predict_function.py --batch-size 2 mobilenet_v2.h5 model
         python3 ../${EMITC}/scripts/savedmodel_to_tf_dialect.py --exported-names predict model model_tf.mlir
         python3 ../${EMITC}/scripts/optimize_tf_dialect.py model_tf.mlir model_tf_opt.mlir
-        python3 ../${EMITC}/scripts/tf_to_mhlo_dialect.py model_tf_opt.mlir model_mhlo.mlir
+        python3 ../${EMITC}/scripts/tf_to_hlo_dialect.py --hlo-dialect mhlo model_tf_opt.mlir model_mhlo.mlir
         sed "s/tf._input_shapes =.*]//" model_mhlo.mlir > ../${E2E}/model_mhlo_noattr.mlir
         sed -i "s/, }/}/" ../${E2E}/model_mhlo_noattr.mlir
+        python3 ../${EMITC}/scripts/tf_to_hlo_dialect.py --hlo-dialect stablehlo model_tf_opt.mlir ../${E2E}/model_stablehlo.mlir
         python3 ../${EMITC}/scripts/tf_to_tosa_dialect.py model_tf_opt.mlir model_tosa.mlir
         sed "s/tf._input_shapes =.*]//" model_tosa.mlir > ../${E2E}/model_tosa_noattr.mlir
         sed -i "s/, }/}/" ../${E2E}/model_tosa_noattr.mlir
@@ -227,6 +228,22 @@
         ../${EMITC}/build_release/bin/emitc-translate --mlir-to-cpp model_emitc.mlir > model_generated.h
         clang++ ../${E2E}/test.cpp -O3 -I `pwd`/../emitc/reference-implementation/include -I `pwd` -o test
         ./test
+    
+    # - name: Run StableHLO e2e test
+    #   run: |
+    #     mkdir tmp-stablehlo
+    #     cd tmp-stablehlo
+    #     ../${EMITC}/build_release/bin/emitc-opt --canonicalize --inline --symbol-dce ../${E2E}/model_stablehlo.mlir > model_canon.mlir
+    #     FUNCTION_NAME=$(grep -oe "@[^(]*" model_canon.mlir)
+    #     FUNCTION_NAME="${FUNCTION_NAME:1}"
+    #     sed "s/$FUNCTION_NAME/predict/g" model_canon.mlir > model_fix_name.mlir
+    #     ../${EMITC}/build_release/bin/emitc-opt \
+    #       --insert-emitc-mhlo-include \
+    #       --convert-stablehlo-to-emitc \
+    #       model_fix_name.mlir > model_emitc.mlir
+    #     ../${EMITC}/build_release/bin/emitc-translate --mlir-to-cpp model_emitc.mlir > model_generated.h
+    #     clang++ ../${E2E}/test.cpp -O3 -I `pwd`/../emitc/reference-implementation/include -I `pwd` -o test
+    #     ./test
 
     - name: Run TOSA e2e test
       run: |
diff --git a/scripts/e2e_test.sh b/scripts/e2e_test.sh
index f873abf..b87a5c4 100755
--- a/scripts/e2e_test.sh
+++ b/scripts/e2e_test.sh
@@ -58,7 +58,7 @@
 python optimize_tf_dialect.py "$OUTPUT_DIR"/model_tf.mlir "$OUTPUT_DIR"/model_tf_opt.mlir
 
 echo "Converting tf dialect to mhlo dialect"
-python tf_to_mhlo_dialect.py "$OUTPUT_DIR"/model_tf_opt.mlir "$OUTPUT_DIR"/model_mhlo.mlir
+python tf_to_hlo_dialect.py --hlo-dialect mhlo "$OUTPUT_DIR"/model_tf_opt.mlir "$OUTPUT_DIR"/model_mhlo.mlir
 
 echo "Removing tf._input_shapes attribute"
 sed "s/tf._input_shapes =.*]//" "$OUTPUT_DIR"/model_mhlo.mlir > "$OUTPUT_DIR"/model_mhlo_noattr.mlir
diff --git a/scripts/tf_to_mhlo_dialect.py b/scripts/tf_to_hlo_dialect.py
similarity index 73%
rename from scripts/tf_to_mhlo_dialect.py
rename to scripts/tf_to_hlo_dialect.py
index 70df76c..cd7e066 100644
--- a/scripts/tf_to_mhlo_dialect.py
+++ b/scripts/tf_to_hlo_dialect.py
@@ -19,11 +19,12 @@
 from tensorflow.python import pywrap_mlir  # pylint: disable=no-name-in-module
 
 
-def convert(model_path: str, output_path: str):
-    pass_pipeline = ",".join([
-        "func.func(xla-legalize-tf)", "canonicalize",
-        "tf-saved-model-optimize-global-tensors"
-    ])
+def convert(model_path: str, output_path: str, hlo_dialect: str):
+    pass_pipeline = ["tf-lower-to-mlprogram-and-hlo"]
+    if hlo_dialect == "mhlo":
+        pass_pipeline.append("stablehlo-legalize-to-hlo")
+    pass_pipeline = ",".join(pass_pipeline)
+    
     with open(model_path) as file:
         mlir = file.read()
 
@@ -36,6 +37,13 @@
 def main():
     parser = argparse.ArgumentParser(
         description="Convert model in tf dialect to mhlo dialect")
+    parser.add_argument(
+        "--hlo-dialect",
+        type=str,
+        choices=["mhlo", "stablehlo"],
+        default="mhlo",
+        help="Which flavor of HLO dialect to export",
+    )
     parser.add_argument("model_path",
                         metavar="model-path",
                         help="Path to tf mlir model")
@@ -44,7 +52,7 @@
                         help="Output path")
     args = parser.parse_args()
 
-    convert(args.model_path, args.output_path)
+    convert(args.model_path, args.output_path, args.hlo_dialect)
 
 
 if __name__ == "__main__":