/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/debug_data_dumper.h"
using tensorflow::errors::InvalidArgument;
namespace tensorflow {
namespace {
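// Returns an InvalidArgument error if output `idx` of `node` has a reference
// type; OkStatus otherwise.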
Status ValidateNonRefOutput(const Node* node, int idx) {
const DataType& dt = node->output_type(idx);
return IsRefType(dt)
? InvalidArgument("Output ", idx, " of node '", node->name(),
"' has a reference type ", DataTypeString(dt))
: OkStatus();
}
// Converts `ninputs` and `inputs` into `input_tensors` and `input_nodes` and
// performs various checks while doing so. `input_nodes` holds the same
// information as `input_tensors`, just in a different structure that makes the
// subsequent processing easier. TODO(iga): Simplify this nested structure.
Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing input ", i, " into function '", fn_name,
"'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while processing input ", i,
" into function '", fn_name, "'");
input_tensors->emplace_back(node, idx);
const auto& iter = input_nodes->find(node);
if (iter == input_nodes->end()) {
input_nodes->insert({node, {idx}});
} else {
auto& indices = iter->second;
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
return InvalidArgument("TF_Output ", node->name(), ":", idx,
" appears more than once in the input list");
}
indices.push_back(idx);
}
}
return OkStatus();
}
// Converts `noutputs` and `outputs` into `output_tensors` and performs various
// checks while doing so.
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing output ", i, " from function '", fn_name,
"'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while creating function '",
fn_name, "'");
output_tensors->emplace_back(node, idx);
}
return OkStatus();
}
// Populates `body_nodes` with the nodes that will become the function's body
// and performs various checks.
Status ComputeBodyNodes(
const TF_Graph* fn_body, const char* fn_name, int num_opers,
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);
if (iter == input_nodes.end()) {
// This node is not referenced in inputs. Add it to the body.
body_nodes->push_back(node);
} else {
// This node is referenced in inputs. Currently, we place an
// artificial restriction and require that when num_opers=-1, such
// nodes must have a single output.
if (node->num_outputs() != 1) {
return InvalidArgument(
"When `num_opers` is set to -1, nodes referenced in `inputs` "
"must have a single output. Node ",
node->name(), " has ", node->num_outputs(),
" outputs. Encountered while creating function '", fn_name, "'");
}
}
}
} else {
body_nodes->reserve(num_opers);
for (int i = 0; i < num_opers; ++i) {
const Node* node = &opers[i]->node;
body_nodes->push_back(node);
}
}
return OkStatus();
}
} // namespace
} // namespace tensorflow
using tensorflow::Node;
using tensorflow::string;
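// Creates a TF_Function from the subgraph of `fn_body` given by `opers` (or,
// when `num_opers` is -1, by all op nodes not referenced in `inputs`), with
// the specified inputs, outputs, and control outputs. On error, sets `status`
// and returns nullptr.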
TF_Function* TF_GraphToFunctionWithControlOutputs(
const TF_Graph* fn_body, const char* fn_name,
unsigned char append_hash_to_fn_name, int num_opers,
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs, const char* const* output_names,
int ncontrol_outputs, const TF_Operation* const* control_outputs,
const char* const* control_output_names, const TF_FunctionOptions* opts,
const char* description, TF_Status* status) {
tensorflow::mutex_lock l(fn_body->mu);
// Process inputs.
std::vector<tensorflow::OutputTensor> input_tensors;
std::unordered_map<const Node*, std::vector<int>> input_nodes;
status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
&input_tensors, &input_nodes);
if (TF_GetCode(status) != TF_OK) return nullptr;
// Process outputs.
std::vector<tensorflow::OutputTensor> output_tensors;
status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
outputs, &output_tensors);
if (TF_GetCode(status) != TF_OK) return nullptr;
// Process output names.
std::vector<string> output_names_vec;
if (output_names) {
output_names_vec.reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names_vec.push_back(string(output_names[i]));
}
}
// Process control output names.
std::vector<string> control_output_names_vec;
if (control_output_names) {
control_output_names_vec.reserve(ncontrol_outputs);
for (int i = 0; i < ncontrol_outputs; ++i) {
control_output_names_vec.push_back(string(control_output_names[i]));
}
}
// Compute body nodes.
std::vector<const Node*> body_nodes;
status->status = tensorflow::ComputeBodyNodes(
fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
if (TF_GetCode(status) != TF_OK) return nullptr;
// Process control output nodes.
std::vector<const Node*> control_output_nodes;
control_output_nodes.reserve(ncontrol_outputs);
for (int i = 0; i < ncontrol_outputs; ++i) {
control_output_nodes.push_back(&control_outputs[i]->node);
}
// Do the actual function creation.
DCHECK(append_hash_to_fn_name <= 1);
tensorflow::FunctionDef fdef;
status->status = tensorflow::GraphToFunctionDef(
fn_body->graph, fn_name, append_hash_to_fn_name != 0,
/*set_stateful_from_nodes=*/true,
/*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
output_tensors, output_names_vec, control_output_nodes,
control_output_names_vec, description, &fdef);
if (TF_GetCode(status) != TF_OK) {
return nullptr;
}
// Dump the op creation stack traces for debugging purposes.
DEBUG_DATA_DUMPER()->DumpOpCreationStackTraces(
fn_name, kDebugGroupOpStacktrace, "initial", &fn_body->graph);
tensorflow::StackTracesMap stack_traces;
for (const Node* n : fn_body->graph.nodes()) {
stack_traces[n->name()] = n->GetStackTrace();
}
TF_Function* tf_function = new TF_Function();
tf_function->record = new tensorflow::FunctionRecord(
std::move(fdef), std::move(stack_traces), false);
return tf_function;
}
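// Like TF_GraphToFunctionWithControlOutputs, but with no control outputs.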
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
unsigned char append_hash_to_fn_name,
int num_opers, const TF_Operation* const* opers,
int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs,
const char* const* output_names,
const TF_FunctionOptions* opts,
const char* description, TF_Status* status) {
return TF_GraphToFunctionWithControlOutputs(
fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
description, status);
}
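// Example: a minimal usage sketch of TF_GraphToFunction. The graphs `body` and
// `dst` and the operation names "x" and "y" are hypothetical; "x" and "y" are
// assumed to already exist in `body`, and "x" must have a single output
// because it is referenced in `inputs` while `num_opers` is -1.
//
//   TF_Status* s = TF_NewStatus();
//   TF_Output in = {TF_GraphOperationByName(body, "x"), 0};
//   TF_Output out = {TF_GraphOperationByName(body, "y"), 0};
//   TF_Function* fn = TF_GraphToFunction(
//       body, "my_fn", /*append_hash_to_fn_name=*/0, /*num_opers=*/-1,
//       /*opers=*/nullptr, /*ninputs=*/1, &in, /*noutputs=*/1, &out,
//       /*output_names=*/nullptr, /*opts=*/nullptr, /*description=*/nullptr,
//       s);
//   if (TF_GetCode(s) == TF_OK) {
//     TF_GraphCopyFunction(dst, fn, /*grad=*/nullptr, s);
//   }
//   TF_DeleteFunction(fn);
//   TF_DeleteStatus(s);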
const char* TF_FunctionName(TF_Function* func) {
return func->record->fdef().signature().name().c_str();
}
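// Adds `func`'s FunctionDef (and stack traces) to `g`'s function library. If
// `grad` is non-null, also adds `grad` and records it as the gradient
// function of `func`.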
void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
const TF_Function* grad, TF_Status* status) {
if (func == nullptr) {
status->status = InvalidArgument(
"'func' argument to TF_GraphCopyFunction cannot be null");
return;
}
tensorflow::mutex_lock l(g->mu);
status->status = g->graph.AddFunctionDef(func->record->fdef(),
func->record->stack_traces());
if (TF_GetCode(status) != TF_OK) return;
if (!grad) return;
status->status = g->graph.AddFunctionDef(grad->record->fdef(),
grad->record->stack_traces());
if (TF_GetCode(status) != TF_OK) return;
tensorflow::GradientDef gdef;
gdef.set_function_name(func->record->fdef().signature().name());
gdef.set_gradient_func(grad->record->fdef().signature().name());
status->status = g->graph.AddGradientDef(std::move(gdef));
}
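// Returns the number of functions registered in `g`'s function library.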
int TF_GraphNumFunctions(TF_Graph* g) {
tensorflow::mutex_lock l(g->mu);
return g->graph.flib_def().num_functions();
}
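// Copies up to `max_func` functions from `g`'s function library into `funcs`
// and returns the number copied. The caller takes ownership of the returned
// TF_Function objects and should release them with TF_DeleteFunction.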
int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
TF_Status* status) {
tensorflow::FunctionDefLibrary lib;
{
tensorflow::mutex_lock l(g->mu);
lib = g->graph.flib_def().ToProto();
}
const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
for (int i = 0; i < len; ++i) {
TF_Function* func = new TF_Function();
func->record = new tensorflow::FunctionRecord(lib.function(i), {}, false);
funcs[i] = func;
}
status->status = ::tensorflow::OkStatus();
return len;
}
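// Serializes `func`'s FunctionDef into `output_func_def`.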
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
TF_Status* status) {
status->status = MessageToBuffer(func->record->fdef(), output_func_def);
}
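// Creates a TF_Function from a serialized FunctionDef proto (`proto_len`
// bytes starting at `proto`). On parse failure, sets `status` and returns
// nullptr.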
TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::FunctionDef fdef;
bool success = fdef.ParseFromArray(proto, proto_len);
if (!success) {
status->status = InvalidArgument(
"Invalid FunctionDef given to TF_FunctionImportFunctionDef");
return nullptr;
}
TF_Function* func = new TF_Function();
func->record = new tensorflow::FunctionRecord(std::move(fdef), {}, false);
status->status = ::tensorflow::OkStatus();
return func;
}
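// Parses a serialized AttrValue proto and stores it under `attr_name` in
// `func`'s attribute map.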
void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::AttrValue attr_value;
if (!attr_value.ParseFromArray(proto, proto_len)) {
status->status = InvalidArgument(
"Unparseable AttrValue proto passed to "
"TF_FunctionSetAttrValueProto");
return;
}
auto fdef_or = func->record->mutable_fdef();
if (!fdef_or.ok()) {
status->status = fdef_or.status();
return;
}
(*(fdef_or.value()->mutable_attr()))[string(attr_name)] = attr_value;
status->status = ::tensorflow::OkStatus();
}
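// Serializes the attribute named `attr_name` of `func` into
// `output_attr_value`. Sets an InvalidArgument status if no such attribute
// exists.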
void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
TF_Buffer* output_attr_value,
TF_Status* status) {
const auto& it = func->record->fdef().attr().find(attr_name);
if (it == func->record->fdef().attr().end()) {
status->status =
InvalidArgument("Function '", func->record->fdef().signature().name(),
"' has no attr named '", attr_name, "'.");
return;
}
status->status = MessageToBuffer(it->second, output_attr_value);
}
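// Releases `func`'s reference to its FunctionRecord and deletes the wrapper.
// Safe to call with nullptr.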
void TF_DeleteFunction(TF_Function* func) {
if (func == nullptr) {
return;
}
func->record->Unref();
func->record = nullptr;
delete func;
}