Add option to lower a TensorFlow model to the StableHLO Dialect (#358)
diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml
index 6387dc9..12386e3 100644
--- a/.github/workflows/build-and-test.yaml
+++ b/.github/workflows/build-and-test.yaml
@@ -97,9 +97,10 @@
python3 ../${EMITC}/scripts/model_to_savedmodel_with_predict_function.py --batch-size 2 mobilenet_v2.h5 model
python3 ../${EMITC}/scripts/savedmodel_to_tf_dialect.py --exported-names predict model model_tf.mlir
python3 ../${EMITC}/scripts/optimize_tf_dialect.py model_tf.mlir model_tf_opt.mlir
- python3 ../${EMITC}/scripts/tf_to_mhlo_dialect.py model_tf_opt.mlir model_mhlo.mlir
+ python3 ../${EMITC}/scripts/tf_to_hlo_dialect.py --hlo-dialect mhlo model_tf_opt.mlir model_mhlo.mlir
sed "s/tf._input_shapes =.*]//" model_mhlo.mlir > ../${E2E}/model_mhlo_noattr.mlir
sed -i "s/, }/}/" ../${E2E}/model_mhlo_noattr.mlir
+ python3 ../${EMITC}/scripts/tf_to_hlo_dialect.py --hlo-dialect stablehlo model_tf_opt.mlir ../${E2E}/model_stablehlo.mlir
python3 ../${EMITC}/scripts/tf_to_tosa_dialect.py model_tf_opt.mlir model_tosa.mlir
sed "s/tf._input_shapes =.*]//" model_tosa.mlir > ../${E2E}/model_tosa_noattr.mlir
sed -i "s/, }/}/" ../${E2E}/model_tosa_noattr.mlir
@@ -227,6 +228,22 @@
../${EMITC}/build_release/bin/emitc-translate --mlir-to-cpp model_emitc.mlir > model_generated.h
clang++ ../${E2E}/test.cpp -O3 -I `pwd`/../emitc/reference-implementation/include -I `pwd` -o test
./test
+
+ # - name: Run StableHLO e2e test
+ # run: |
+ # mkdir tmp-stablehlo
+ # cd tmp-stablehlo
+ # ../${EMITC}/build_release/bin/emitc-opt --canonicalize --inline --symbol-dce ../${E2E}/model_stablehlo.mlir > model_canon.mlir
+ # FUNCTION_NAME=$(grep -oe "@[^(]*" model_canon.mlir)
+ # FUNCTION_NAME="${FUNCTION_NAME:1}"
+ # sed "s/$FUNCTION_NAME/predict/g" model_canon.mlir > model_fix_name.mlir
+ # ../${EMITC}/build_release/bin/emitc-opt \
+ # --insert-emitc-mhlo-include \
+ # --convert-stablehlo-to-emitc \
+ # model_fix_name.mlir > model_emitc.mlir
+ # ../${EMITC}/build_release/bin/emitc-translate --mlir-to-cpp model_emitc.mlir > model_generated.h
+ # clang++ ../${E2E}/test.cpp -O3 -I `pwd`/../emitc/reference-implementation/include -I `pwd` -o test
+ # ./test
- name: Run TOSA e2e test
run: |
diff --git a/scripts/e2e_test.sh b/scripts/e2e_test.sh
index f873abf..b87a5c4 100755
--- a/scripts/e2e_test.sh
+++ b/scripts/e2e_test.sh
@@ -58,7 +58,7 @@
python optimize_tf_dialect.py "$OUTPUT_DIR"/model_tf.mlir "$OUTPUT_DIR"/model_tf_opt.mlir
echo "Converting tf dialect to mhlo dialect"
-python tf_to_mhlo_dialect.py "$OUTPUT_DIR"/model_tf_opt.mlir "$OUTPUT_DIR"/model_mhlo.mlir
+python tf_to_hlo_dialect.py --hlo-dialect mhlo "$OUTPUT_DIR"/model_tf_opt.mlir "$OUTPUT_DIR"/model_mhlo.mlir
echo "Removing tf._input_shapes attribute"
sed "s/tf._input_shapes =.*]//" "$OUTPUT_DIR"/model_mhlo.mlir > "$OUTPUT_DIR"/model_mhlo_noattr.mlir
diff --git a/scripts/tf_to_mhlo_dialect.py b/scripts/tf_to_hlo_dialect.py
similarity index 73%
rename from scripts/tf_to_mhlo_dialect.py
rename to scripts/tf_to_hlo_dialect.py
index 70df76c..cd7e066 100644
--- a/scripts/tf_to_mhlo_dialect.py
+++ b/scripts/tf_to_hlo_dialect.py
@@ -19,11 +19,12 @@
from tensorflow.python import pywrap_mlir # pylint: disable=no-name-in-module
-def convert(model_path: str, output_path: str):
- pass_pipeline = ",".join([
- "func.func(xla-legalize-tf)", "canonicalize",
- "tf-saved-model-optimize-global-tensors"
- ])
+def convert(model_path: str, output_path: str, hlo_dialect: str):
+ pass_pipeline = ["tf-lower-to-mlprogram-and-hlo"]
+ if hlo_dialect == "mhlo":
+ pass_pipeline.append("stablehlo-legalize-to-hlo")
+ pass_pipeline = ",".join(pass_pipeline)
+
with open(model_path) as file:
mlir = file.read()
@@ -36,6 +37,13 @@
def main():
parser = argparse.ArgumentParser(
description="Convert model in tf dialect to mhlo dialect")
+ parser.add_argument(
+ "--hlo-dialect",
+ type=str,
+ choices=["mhlo", "stablehlo"],
+ default="mhlo",
+ help="Which flavor of HLO dialect to export",
+ )
parser.add_argument("model_path",
metavar="model-path",
help="Path to tf mlir model")
@@ -44,7 +52,7 @@
help="Output path")
args = parser.parse_args()
- convert(args.model_path, args.output_path)
+ convert(args.model_path, args.output_path, args.hlo_dialect)
if __name__ == "__main__":