blob: ae904e99edd1caa9b3feecb889db29a9f82aed67 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
#define TENSORFLOW_LITE_SIGNATURE_RUNNER_H_
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/internal/signature_def.h"
namespace tflite {
// Forward declarations only: each class below is named in a `friend`
// declaration inside SignatureRunner, so a full definition is not needed in
// this header.
namespace impl {
class Interpreter; // Befriended by SignatureRunner as ::tflite::impl::Interpreter.
}
class SignatureRunnerJNIHelper; // Befriended by SignatureRunner.
class TensorHandle; // Befriended by SignatureRunner.
class SignatureRunnerHelper; // Befriended by SignatureRunner.
/// WARNING: Experimental interface, subject to change
///
/// SignatureRunner class for running TFLite models using SignatureDef.
///
/// Usage:
///
/// <pre><code>
/// // Create model from file. Note that the model instance must outlive the
/// // interpreter instance.
/// auto model = tflite::FlatBufferModel::BuildFromFile(...);
/// if (model == nullptr) {
/// // Return error.
/// }
///
/// // Create an Interpreter with an InterpreterBuilder.
/// std::unique_ptr<tflite::Interpreter> interpreter;
/// tflite::ops::builtin::BuiltinOpResolver resolver;
/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
/// // Return failure.
/// }
///
/// // Get the list of signatures and check it.
/// auto signature_defs = interpreter->signature_def_names();
/// if (signature_defs.empty()) {
/// // Return error.
/// }
///
/// // Get pointer to the SignatureRunner instance corresponding to a signature.
/// // Note that the pointed SignatureRunner instance has lifetime same as the
/// // Interpreter instance.
/// tflite::SignatureRunner* runner =
/// interpreter->GetSignatureRunner(signature_defs[0]->c_str());
/// if (runner == nullptr) {
/// // Return error.
/// }
/// if (runner->AllocateTensors() != kTfLiteOk) {
/// // Return failure.
/// }
///
/// // Set input data. In this example, the input tensor has float type.
/// float* input = runner->input_tensor(0)->data.f;
/// for (int i = 0; i < input_size; i++) {
/// input[i] = ...;
/// }
/// runner->Invoke();
/// </code></pre>
///
/// WARNING: This class is *not* thread-safe. The client is responsible for
/// ensuring serialized interaction to avoid data races and undefined behavior.
///
/// SignatureRunner and Interpreter share the same underlying data. Calling
/// methods on an Interpreter object will affect the state in corresponding
/// SignatureRunner objects. Therefore, it is recommended not to call other
/// Interpreter methods after calling GetSignatureRunner to create
/// SignatureRunner instances.
class SignatureRunner {
 public:
  // NOTE: the read-only accessors below are all const-qualified so that they
  // can be called through a `const SignatureRunner&`, consistently with
  // input_size() / output_size().

  /// Returns the key for the corresponding signature.
  const std::string& signature_key() const {
    return signature_def_->signature_key;
  }

  /// Returns the number of inputs.
  size_t input_size() const { return subgraph_->inputs().size(); }

  /// Returns the number of outputs.
  size_t output_size() const { return subgraph_->outputs().size(); }

  /// Read-only access to list of signature input names.
  const std::vector<const char*>& input_names() const { return input_names_; }

  /// Read-only access to list of signature output names.
  const std::vector<const char*>& output_names() const {
    return output_names_;
  }

  /// Returns the input tensor identified by 'input_name' in the
  /// given signature. Returns nullptr if the given name is not valid.
  TfLiteTensor* input_tensor(const char* input_name);

  /// Returns the output tensor identified by 'output_name' in the
  /// given signature. Returns nullptr if the given name is not valid.
  const TfLiteTensor* output_tensor(const char* output_name) const;

  /// Changes the dimensionality of a given tensor. Note, this is only
  /// acceptable for tensors that are inputs.
  /// Returns status of failure or success. Note that this doesn't actually
  /// resize any existing buffers. A call to AllocateTensors() is required to
  /// change the tensor input buffer.
  TfLiteStatus ResizeInputTensor(const char* input_name,
                                 const std::vector<int>& new_size);

  /// Changes the dimensionality of a given tensor. This is only acceptable for
  /// tensor indices that are inputs or variables. Only unknown dimensions can
  /// be resized with this function. Unknown dimensions are indicated as `-1` in
  /// the `dims_signature` attribute of a TfLiteTensor.
  /// Returns status of failure or success. Note that this doesn't actually
  /// resize any existing buffers. A call to AllocateTensors() is required to
  /// change the tensor input buffer.
  TfLiteStatus ResizeInputTensorStrict(const char* input_name,
                                       const std::vector<int>& new_size);

  /// Updates allocations for all tensors related to the given signature.
  /// Forwards directly to the underlying Subgraph.
  TfLiteStatus AllocateTensors() { return subgraph_->AllocateTensors(); }

  /// Invokes the signature runner (runs the graph identified by the given
  /// signature in dependency order).
  TfLiteStatus Invoke();

  /// Attempts to cancel the in-flight invocation, if any.
  /// This will not affect calls to `Invoke` that happen after this.
  /// Non-blocking and thread safe.
  /// Returns kTfLiteError if cancellation is not enabled, otherwise returns
  /// kTfLiteOk.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus Cancel() { return subgraph_->Cancel(); }

 private:
  // The life cycle of SignatureRunner depends on the life cycle of Subgraph,
  // which is owned by an Interpreter. Therefore, the Interpreter takes the
  // responsibility to create and manage SignatureRunner objects to make sure
  // SignatureRunner objects don't outlive their corresponding Subgraph objects.
  // This is why the constructor is private and only reachable by friends.
  SignatureRunner(const internal::SignatureDef* signature_def,
                  Subgraph* subgraph);

  friend class ::tflite::impl::Interpreter;
  friend class SignatureRunnerJNIHelper;
  friend class TensorHandle;
  friend class SignatureRunnerHelper;

  // The SignatureDef object is owned by the interpreter.
  const internal::SignatureDef* signature_def_;
  // The Subgraph object is owned by the interpreter.
  Subgraph* subgraph_;
  // The list of input tensor names.
  std::vector<const char*> input_names_;
  // The list of output tensor names.
  std::vector<const char*> output_names_;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_SIGNATURE_RUNNER_H_