blob: 6decc3578038c9ee101a5660c862db295e15843d [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for converting between TFLite and C++."""
from absl import app
from absl import flags
from absl import logging
from compiler_opt.tools import tflite_to_cpp_lib
# Command-line interface of the conversion script. Flags followed by a
# `mark_flag_as_required` call must be supplied on the command line.

# --- Input model and output location. ---
flags.DEFINE_string('input', None,
                    'Input, which should be a path to a tflite model')
flags.mark_flag_as_required('input')

flags.DEFINE_string('output_dir', None,
                    'Output directory for the generated files')
flags.mark_flag_as_required('output_dir')

# --- Identity of the generated model. ---
flags.DEFINE_string(
    'name',
    None,
    ('Name to use for the model. This will be in the filenames and also will'
     ' be used to identify the model within LLVM. This should be unique'
     ' between models'),
)
flags.mark_flag_as_required('name')

flags.DEFINE_string(
    'base_class',
    None,
    ('Base class to use for the generated model. This is used when'
     ' registering the model in LLVM. This should be a fully-qualified name,'
     ' e.g. ::llvm::MLInlineOzEmitCModel'),
)
flags.mark_flag_as_required('base_class')

# Multi-valued and not marked as required, so it is None when never passed.
flags.DEFINE_multi_string(
    'additional_headers',
    None,
    ('Additional headers to include for the model, for instance the header'
     ' defining the base class. Should be of the form'
     ' --additional_headers="llvm/Analysis/MyHeader.h"'),
)

# --- External tools and runtime used by the conversion pipeline. ---
flags.DEFINE_string(
    'iree_import_tflite_path',
    None,
    'Path to the iree-import-tflite binary from iree repository',
)
flags.mark_flag_as_required('iree_import_tflite_path')

flags.DEFINE_string(
    'emitc_opt_path',
    None,
    'Path to the emitc-opt binary from the emitc repository',
)
flags.mark_flag_as_required('emitc_opt_path')

flags.DEFINE_string(
    'mlir_translate_path',
    None,
    'Path to the mlir-translate binary from the llvm repository',
)
flags.mark_flag_as_required('mlir_translate_path')

flags.DEFINE_string(
    'emitc_runtime_path',
    None,
    'Path to the emitc runtime to embed in the generated c++ model',
)
flags.mark_flag_as_required('emitc_runtime_path')

# --- Optional formatting of the generated files. ---
# Formatting only happens when --clang_format_path is given (see `main`).
flags.DEFINE_string(
    'clang_format_path',
    None,
    ('(Optional) path to clang-format binary to use to format the resulting'
     ' files'),
)
flags.DEFINE_string(
    'clang_format_style',
    'llvm',
    'Style argument to use for clang format',
)

# Module-level handle through which `main` reads the parsed flag values.
FLAGS = flags.FLAGS
def main(argv):
  """Runs the TFLite -> C++ conversion pipeline and writes the result.

  Every stage is delegated to `tflite_to_cpp_lib`, in this order:
  `tflite_to_tosa`, `tosa_to_emitc_mlir`, `emitc_mlir_to_cpp`,
  `embed_runtime`, `add_additional_headers`,
  `print_llvm_registration_handle`, `add_license_and_notice` and, only when
  --clang_format_path is set, `format_model`. The model's `cpp` and `hdr`
  text are then written as UTF-8 to the paths returned by
  `get_model_cpp_path` and `get_model_hdr_path` for --output_dir.

  Args:
    argv: Positional arguments passed in by `app.run`; unused.
  """
  del argv  # All inputs are read from FLAGS.
  lib = tflite_to_cpp_lib

  logging.info('Beginning conversion pipeline.')

  # Lowering chain: tflite file -> tosa -> emitc mlir -> c++ model.
  tosa_ir = lib.tflite_to_tosa(
      tflite_path=FLAGS.input,
      iree_import_tflite_path=FLAGS.iree_import_tflite_path,
  )
  emitc_ir = lib.tosa_to_emitc_mlir(
      tosa=tosa_ir,
      emitc_opt_path=FLAGS.emitc_opt_path,
  )
  model = lib.emitc_mlir_to_cpp(
      emitc_mlir=emitc_ir,
      mlir_translate_path=FLAGS.mlir_translate_path,
      name=FLAGS.name,
      base_class=FLAGS.base_class,
  )

  # Post-processing of the generated model; the order of these calls is kept
  # exactly as is.
  model = lib.embed_runtime(model=model, runtime_path=FLAGS.emitc_runtime_path)
  model = lib.add_additional_headers(
      model=model,
      additional_headers=FLAGS.additional_headers,
  )
  lib.print_llvm_registration_handle(model=model, base_class=FLAGS.base_class)
  model = lib.add_license_and_notice(model=model)

  # Formatting is skipped unless a clang-format binary was supplied.
  clang_format = FLAGS.clang_format_path
  if clang_format:
    model = lib.format_model(
        model=model,
        clang_format_path=clang_format,
        clang_format_style=FLAGS.clang_format_style,
    )

  cpp_path = lib.get_model_cpp_path(model, FLAGS.output_dir)
  hdr_path = lib.get_model_hdr_path(model, FLAGS.output_dir)
  logging.info('Writing generated files to [%s] and [%s].', cpp_path, hdr_path)

  # The source file is written first, then the header.
  for destination, contents in ((cpp_path, model.cpp), (hdr_path, model.hdr)):
    with open(destination, 'wt', encoding='utf-8') as out_file:
      out_file.write(contents)

  logging.info('Done.')
# Run `main` through absl's `app.run` only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
  app.run(main)