blob: 19e6d626e2386cffb2e618f796fc017d632c9c40 [file] [log] [blame]
//===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for interfacing with tensorflow C APIs.
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include <cassert>
#include <cstring>
using namespace llvm;
namespace {

// Owning handles for TensorFlow C API objects. Each one releases its object
// through the matching TF_Delete* function when it goes out of scope.
using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
using TFSessionOptionsPtr =
    std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;

/// Calls TF_InitMain with a synthetic command line: argc == 1 and an empty
/// program name. It then records that initialization happened.
struct TFInitializer {
  TFInitializer() {
    assert(!IsInitialized && "TFInitialized should be called only once");
    const char *ProgramName = "";
    const char **Argv = &ProgramName;
    int ArgCount = 1;
    TF_InitMain(ProgramName, &ArgCount, const_cast<char ***>(&Argv));
    IsInitialized = true;
  }

  bool IsInitialized = false;
};

// The initializer is created through llvm::ManagedStatic, so TF_InitMain runs
// when TFLibInitializer is first accessed.
llvm::ManagedStatic<TFInitializer> TFLibInitializer;

/// Touches the initializer (constructing it if needed) and reports whether
/// TensorFlow initialization completed.
bool ensureInitTF() {
  const bool Ready = TFLibInitializer->IsInitialized;
  return Ready;
}

/// Creates an empty TF graph owned by a TFGraphPtr.
TFGraphPtr createTFGraph() {
  TF_Graph *RawGraph = TF_NewGraph();
  return TFGraphPtr(RawGraph, &TF_DeleteGraph);
}

/// Creates a fresh TF status object owned by a TFStatusPtr.
TFStatusPtr createTFStatus() {
  TF_Status *RawStatus = TF_NewStatus();
  return TFStatusPtr(RawStatus, &TF_DeleteStatus);
}

/// Creates default session options owned by a TFSessionOptionsPtr.
TFSessionOptionsPtr createTFSessionOptions() {
  TF_SessionOptions *RawOptions = TF_NewSessionOptions();
  return TFSessionOptionsPtr(RawOptions, &TF_DeleteSessionOptions);
}

} // namespace
namespace llvm {
class EvaluationResultImpl {
public:
EvaluationResultImpl(size_t OutputSize)
: OutputSize(OutputSize), Output(OutputSize){};
~EvaluationResultImpl() {
for (auto *P : Output)
if (P)
TF_DeleteTensor(P);
}
EvaluationResultImpl(const EvaluationResultImpl &) = delete;
EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
std::vector<TF_Tensor *> &getOutput() { return Output; }
private:
const size_t OutputSize;
std::vector<TF_Tensor *> Output;
};
class TFModelEvaluatorImpl {
public:
TFModelEvaluatorImpl(StringRef SavedModelPath,
const std::vector<std::string> &InputNames,
const std::vector<std::string> &OutputNames,
const char *Tags);
bool isValid() const { return IsValid; }
size_t OutputSize() const { return OutputFeed.size(); }
void evaluate(TF_Tensor **Output, TF_Status *Status) {
TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
nullptr, 0, nullptr, Status);
}
void initInput(size_t Index, TF_DataType Type,
const std::vector<int64_t> &Dimensions);
const std::vector<TF_Tensor *> &getInput() const { return Input; }
~TFModelEvaluatorImpl();
private:
/// The objects necessary for carrying out an evaluation of the SavedModel.
/// They are expensive to set up, and we maintain them accross all the
/// evaluations of the model.
TF_Session *Session = nullptr;
TFGraphPtr Graph;
TFSessionOptionsPtr Options;
/// The specification of the input nodes.
std::vector<TF_Output> InputFeed;
/// The input tensors. They must match by index of the corresponding InputFeed
/// value. We set up the tensors once and just mutate theirs scalars before
/// each evaluation. The input tensors keep their value after an evaluation.
std::vector<TF_Tensor *> Input;
/// The specification of the output nodes. When evaluating, the tensors in the
/// output tensor vector must match by index the corresponding element in the
/// OutputFeed.
std::vector<TF_Output> OutputFeed;
void invalidate() { IsValid = false; }
bool IsValid = true;
/// Reusable utility for ensuring we can bind the requested Name to a node in
/// the SavedModel Graph.
bool checkReportAndInvalidate(const TF_Output &Output, StringRef Name);
};
} // namespace llvm
/// Loads the SavedModel at \p SavedModelPath (using the single tag \p Tags) and
/// binds each name in \p InputNames / \p OutputNames to output 0 of the graph
/// operation with that name.
///
/// On any failure, it prints a diagnostic to errs() and leaves the object
/// invalid (isValid() == false).
TFModelEvaluatorImpl::TFModelEvaluatorImpl(
    StringRef SavedModelPath, const std::vector<std::string> &InputNames,
    const std::vector<std::string> &OutputNames, const char *Tags)
    : Graph(createTFGraph()), Options(createTFSessionOptions()),
      InputFeed(InputNames.size()), Input(InputNames.size()),
      OutputFeed(OutputNames.size()) {
  if (!ensureInitTF()) {
    errs() << "Tensorflow should have been initialized";
    // Without this, the evaluator would claim to be valid while having no
    // session, and a later evaluate() would run against a null session.
    invalidate();
    return;
  }
  auto Status = createTFStatus();
  Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
                                         SavedModelPath.str().c_str(), &Tags, 1,
                                         Graph.get(), nullptr, Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    invalidate();
    // The model failed to load. Binding node names against the graph would
    // most likely just add misleading "Could not find" errors.
    return;
  }
  for (size_t I = 0; I < InputNames.size(); ++I) {
    InputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), InputNames[I].c_str()), 0};
    if (!checkReportAndInvalidate(InputFeed[I], InputNames[I]))
      return;
  }
  for (size_t I = 0; I < OutputNames.size(); ++I) {
    OutputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), OutputNames[I].c_str()), 0};
    if (!checkReportAndInvalidate(OutputFeed[I], OutputNames[I]))
      return;
  }
}
/// Builds the implementation object. If loading or binding the model failed,
/// the implementation is dropped right away, leaving Impl null.
TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
                                   const std::vector<std::string> &InputNames,
                                   const std::vector<std::string> &OutputNames,
                                   const char *Tags)
    : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputNames, OutputNames,
                                    Tags)) {
  if (Impl->isValid())
    return;
  Impl.reset();
}
/// Releases the input tensors and then, if one was created, the TF session.
/// A failure to delete the session is reported to errs().
TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
  for (TF_Tensor *Tensor : Input)
    TF_DeleteTensor(Tensor);

  // Session is still null if the SavedModel was never loaded successfully.
  if (!Session)
    return;

  auto DeleteStatus = createTFStatus();
  TF_DeleteSession(Session, DeleteStatus.get());
  Session = nullptr;
  if (TF_GetCode(DeleteStatus.get()) != TF_Code::TF_OK)
    errs() << "Could not delete TF session";
}
/// Returns true if \p Output refers to an actual graph operation. Otherwise it
/// reports the missing \p Name, marks the evaluator invalid, and returns false.
bool TFModelEvaluatorImpl::checkReportAndInvalidate(const TF_Output &Output,
                                                    StringRef Name) {
  if (!Output.oper) {
    errs() << "Could not find TF_Output named: " + Name;
    IsValid = false;
    return false;
  }
  return true;
}
/// Runs the model once on the current inputs.
///
/// \returns the evaluation result, or None if the evaluator is invalid or the
/// run failed. After a failed run, the TF error message is printed and the
/// implementation is discarded.
Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
  if (!isValid())
    return None;

  auto Result = std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
  auto RunStatus = createTFStatus();
  Impl->evaluate(Result->getOutput().data(), RunStatus.get());

  if (TF_GetCode(RunStatus.get()) == TF_Code::TF_OK)
    return EvaluationResult(std::move(Result));

  errs() << TF_Message(RunStatus.get());
  Impl.reset();
  return None;
}
void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
const std::vector<int64_t> &Dimensions) {
int64_t TotalSize = TF_DataTypeSize(Type);
for (auto &D : Dimensions)
TotalSize *= D;
Input[Index] =
TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
}
void *TFModelEvaluator::getUntypedInput(size_t Index) {
return TF_TensorData(Impl->getInput()[Index]);
}
TFModelEvaluator::EvaluationResult::EvaluationResult(
std::unique_ptr<EvaluationResultImpl> Impl)
: Impl(std::move(Impl)) {}
TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
: Impl(std::move(Other.Impl)) {}
void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
return TF_TensorData(Impl->getOutput()[Index]);
}
/// Public entry point. \p TypeIndex is a TF_DataType value, as returned by
/// getModelTypeIndex<T>().
void TFModelEvaluator::initInput(size_t Index, int TypeIndex,
                                 const std::vector<int64_t> &Dimensions) {
  const auto Type = static_cast<TF_DataType>(TypeIndex);
  Impl->initInput(Index, Type, Dimensions);
}
// Mapping from C++ element types to the TF_DataType enumerator used as the
// "type index" argument of TFModelEvaluator::initInput.
template <> int TFModelEvaluator::getModelTypeIndex<float>() {
  return static_cast<int>(TF_FLOAT);
}
template <> int TFModelEvaluator::getModelTypeIndex<double>() {
  return static_cast<int>(TF_DOUBLE);
}
template <> int TFModelEvaluator::getModelTypeIndex<int8_t>() {
  return static_cast<int>(TF_INT8);
}
template <> int TFModelEvaluator::getModelTypeIndex<uint8_t>() {
  return static_cast<int>(TF_UINT8);
}
template <> int TFModelEvaluator::getModelTypeIndex<int16_t>() {
  return static_cast<int>(TF_INT16);
}
template <> int TFModelEvaluator::getModelTypeIndex<uint16_t>() {
  return static_cast<int>(TF_UINT16);
}
template <> int TFModelEvaluator::getModelTypeIndex<int32_t>() {
  return static_cast<int>(TF_INT32);
}
template <> int TFModelEvaluator::getModelTypeIndex<uint32_t>() {
  return static_cast<int>(TF_UINT32);
}
template <> int TFModelEvaluator::getModelTypeIndex<int64_t>() {
  return static_cast<int>(TF_INT64);
}
template <> int TFModelEvaluator::getModelTypeIndex<uint64_t>() {
  return static_cast<int>(TF_UINT64);
}
// Defined here, where EvaluationResultImpl is complete, so that the owning
// Impl member can destroy it.
TFModelEvaluator::EvaluationResult::~EvaluationResult() = default;
// Defined here, where TFModelEvaluatorImpl is complete, so that the owning
// Impl member can destroy it.
TFModelEvaluator::~TFModelEvaluator() = default;