// blob: 59f978000e66f91694eb9da5cb25c4eb0ea24210 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels.h"
#include <memory>
#include <vector>
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
// Required for IS_MOBILE_PLATFORM definition
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/tsl/framework/device_id_utils.h"
#include "tensorflow/tsl/platform/statusor.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
using tensorflow::errors::InvalidArgument;
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.
// C++ callable type that the opaque TF_AsyncOpKernelDoneCallback handle is
// cast to and from in this file.
using AsyncOpKernelDoneCallback = std::function<void()>;

// Invokes the completion callback referred to by `done`.
void TF_RunAsyncOpKernelDoneCallback(TF_AsyncOpKernelDoneCallback* done) {
  auto* callback = reinterpret_cast<AsyncOpKernelDoneCallback*>(done);
  (*callback)();
}
// Bundles the C++ KernelDefBuilder with the C callbacks that implement a
// kernel. Exactly one of `compute_function` / `compute_async_function` is
// populated by the TF_New*KernelBuilder functions in this file; the other is
// set to nullptr.
struct TF_KernelBuilder {
// KernelDef builder; released with `delete` in TF_DeleteKernelBuilder.
::tensorflow::KernelDefBuilder* cc_builder;
// Creates the kernel's opaque state from the construction context.
void* (*create_function)(TF_OpKernelConstruction*);
// Synchronous compute callback; receives the state from `create_function`.
void (*compute_function)(void*, TF_OpKernelContext*);
// Asynchronous compute callback; additionally receives a done-callback handle.
void (*compute_async_function)(void*, TF_OpKernelContext*,
TF_AsyncOpKernelDoneCallback* done);
// Destroys the state returned by `create_function`.
void (*delete_function)(void*);
};
// Allocates a builder for a synchronous kernel implementing op `op_name` on
// device `device_name`. The asynchronous compute slot is left as nullptr.
// The returned builder is heap-allocated; TF_DeleteKernelBuilder releases it.
TF_KernelBuilder* TF_NewKernelBuilder(
    const char* op_name, const char* device_name,
    void* (*create_func)(TF_OpKernelConstruction*),
    void (*compute_func)(void*, TF_OpKernelContext*),
    void (*delete_func)(void*)) {
  auto* builder = new TF_KernelBuilder;

  auto* def_builder = new ::tensorflow::KernelDefBuilder(op_name);
  def_builder->Device(device_name);
  builder->cc_builder = def_builder;

  builder->create_function = create_func;
  builder->delete_function = delete_func;
  builder->compute_function = compute_func;
  builder->compute_async_function = nullptr;
  return builder;
}
// Allocates a builder for an asynchronous kernel implementing op `op_name` on
// device `device_name`. The synchronous compute slot is left as nullptr.
// The returned builder is heap-allocated; TF_DeleteKernelBuilder releases it.
TF_KernelBuilder* TF_NewAsyncKernelBuilder(
    const char* op_name, const char* device_name,
    void* (*create_func)(TF_OpKernelConstruction*),
    void (*compute_async_func)(void*, TF_OpKernelContext*,
                               TF_AsyncOpKernelDoneCallback* done),
    void (*delete_func)(void*)) {
  auto* builder = new TF_KernelBuilder;

  auto* def_builder = new ::tensorflow::KernelDefBuilder(op_name);
  def_builder->Device(device_name);
  builder->cc_builder = def_builder;

  builder->create_function = create_func;
  builder->delete_function = delete_func;
  builder->compute_function = nullptr;
  builder->compute_async_function = compute_async_func;
  return builder;
}
// Releases `builder` together with the KernelDefBuilder it holds.
// Passing nullptr is a no-op.
void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
  if (builder == nullptr) return;
  delete builder->cc_builder;
  delete builder;
}
namespace tensorflow {
namespace {
// Expands to one `switch` case that adds a constraint on `attr_name` for the
// C++ type `type`. Relies on `kernel_builder` and `attr_name` being in scope.
#define CASE(type)                                               \
  case DataTypeToEnum<type>::value:                              \
    kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
    break;

// Adds a type constraint on attribute `attr_name` for `dtype` to the wrapped
// KernelDefBuilder. `status` is OK on success and Unimplemented when `dtype`
// is not covered by any of the type lists below.
void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
                       const DataType dtype, TF_Status* status) {
  // Start from OK; only the `default` branch replaces it with an error.
  TF_SetStatus(status, TF_OK, "");
  // This needs to be under tensorflow:: namespace so that
  // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
  switch (dtype) {
    TF_CALL_ALL_TYPES(CASE);
    TF_CALL_QUANTIZED_TYPES(CASE);
    TF_CALL_quint16(CASE);
    TF_CALL_qint16(CASE);
    default:
      status->status = errors::Unimplemented("Unexpected type ", dtype);
      break;
  }
}
#undef CASE
} // namespace
} // namespace tensorflow
namespace {
// Looks up attribute `attr_name` on the NodeDef behind `ctx`.
// Returns the attribute when present. Otherwise returns nullptr and stores an
// InvalidArgument error in `status`. `status` is not modified on success.
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
                                          const char* attr_name,
                                          TF_Status* status) {
  const auto& node_def =
      reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->def();
  const tensorflow::AttrValue* found =
      ::tensorflow::AttrSlice(node_def).Find(attr_name);
  if (found != nullptr) return found;

  status->status = InvalidArgument("Operation '", node_def.name(),
                                   "' has no attr named '", attr_name, "'.");
  return nullptr;
}
} // namespace
// Constrains attribute `attr_name` of the kernel being built to `type`.
// The outcome is reported through `status` by AddTypeConstraint.
void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
                                     const char* attr_name,
                                     const TF_DataType type,
                                     TF_Status* status) {
  tensorflow::AddTypeConstraint(kernel_builder, attr_name,
                                static_cast<tensorflow::DataType>(type),
                                status);
}
// Forwards `arg_name` to KernelDefBuilder::HostMemory on the wrapped builder.
void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
                                 const char* arg_name) {
  ::tensorflow::KernelDefBuilder* def_builder = kernel_builder->cc_builder;
  def_builder->HostMemory(arg_name);
}
// Forwards `priority_number` to KernelDefBuilder::Priority on the wrapped
// builder.
void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
                               int32_t priority_number) {
  ::tensorflow::KernelDefBuilder* def_builder = kernel_builder->cc_builder;
  def_builder->Priority(priority_number);
}
// Forwards `label` to KernelDefBuilder::Label on the wrapped builder.
void TF_KernelBuilder_Label(TF_KernelBuilder* kernel_builder,
                            const char* label) {
  ::tensorflow::KernelDefBuilder* def_builder = kernel_builder->cc_builder;
  def_builder->Label(label);
}
namespace tensorflow {
namespace {
// An OpKernel whose methods delegate to C function pointers.
class COpKernel : public OpKernel {
public:
explicit COpKernel(OpKernelConstruction* ctx,
void* (*create_func)(TF_OpKernelConstruction*),
void (*compute_func)(void*, TF_OpKernelContext*),
void (*delete_func)(void*))
: OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
if (create_func != nullptr) {
c_kernel_ =
(*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
} else {
c_kernel_ = nullptr;
}
}
void Compute(OpKernelContext* ctx) override {
(*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
}
~COpKernel() override {
if (delete_func_ != nullptr) {
(*delete_func_)(c_kernel_);
}
}
private:
void (*compute_func_)(void*, TF_OpKernelContext* context);
void (*delete_func_)(void*);
void* c_kernel_;
};
class CAsyncOpKernel : public AsyncOpKernel {
public:
explicit CAsyncOpKernel(
OpKernelConstruction* ctx, void* (*create_func)(TF_OpKernelConstruction*),
void (*compute_async_func)(void*, TF_OpKernelContext*,
TF_AsyncOpKernelDoneCallback*),
void (*delete_func)(void*))
: AsyncOpKernel(ctx),
compute_async_func_(compute_async_func),
delete_func_(delete_func) {
if (create_func != nullptr) {
c_kernel_ =
(*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
} else {
c_kernel_ = nullptr;
}
}
void Compute(OpKernelContext* ctx) override {
Notification n;
ComputeAsync(ctx, [&n]() { n.Notify(); });
n.WaitForNotification();
}
void ComputeAsync(OpKernelContext* ctx, AsyncOpKernelDoneCallback done) {
(*compute_async_func_)(
c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx),
reinterpret_cast<TF_AsyncOpKernelDoneCallback*>(&done));
}
CAsyncOpKernel* AsAsync() override { return this; }
~CAsyncOpKernel() override {
if (delete_func_ != nullptr) {
(*delete_func_)(c_kernel_);
}
}
private:
void (*compute_async_func_)(void*, TF_OpKernelContext* context,
TF_AsyncOpKernelDoneCallback* done);
void (*delete_func_)(void*);
void* c_kernel_;
};
// Kernel factory backed by a TF_KernelBuilder. Produces a COpKernel when the
// builder carries a synchronous compute callback and a CAsyncOpKernel
// otherwise. The builder is deleted together with the factory.
class KernelBuilderFactory
    : public ::tensorflow::kernel_factory::OpKernelFactory {
 public:
  explicit KernelBuilderFactory(TF_KernelBuilder* builder)
      : builder_(builder) {}

  ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }

  ::tensorflow::OpKernel* Create(
      ::tensorflow::OpKernelConstruction* context) override {
    const bool is_sync = builder_->compute_function != nullptr;
    if (!is_sync) {
      return new ::tensorflow::CAsyncOpKernel(
          context, builder_->create_function,
          builder_->compute_async_function, builder_->delete_function);
    }
    return new ::tensorflow::COpKernel(context, builder_->create_function,
                                       builder_->compute_function,
                                       builder_->delete_function);
  }

 private:
  TF_KernelBuilder* builder_;
};
} // namespace
} // namespace tensorflow
// Registers `builder` under kernel class name `name`, deriving the KernelDef
// from the builder itself (no serialized KernelDef is supplied).
void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
                              TF_Status* status) {
  const char* const no_serialized_kernel_def = nullptr;
  TF_RegisterKernelBuilderWithKernelDef(no_serialized_kernel_def, name,
                                        builder, status);
}
void TF_RegisterKernelBuilderWithKernelDef(const char* serialized_kernel_def,
const char* name,
TF_KernelBuilder* builder,
TF_Status* status) {
using tensorflow::register_kernel::Name;
if (serialized_kernel_def == nullptr) {
// If user doesn't provide a serialized KernelDef, use the kernel builder
// to build a new one.
tensorflow::kernel_factory::OpKernelRegistrar(
builder->cc_builder->Build(), name,
std::make_unique<tensorflow::KernelBuilderFactory>(builder));
TF_SetStatus(status, TF_OK, "");
return;
}
tensorflow::KernelDef* kernel_def = new tensorflow::KernelDef();
bool success = kernel_def->ParsePartialFromString(serialized_kernel_def);
if (!success) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Error parsing serialized KernelDef.");
return;
}
tensorflow::kernel_factory::OpKernelRegistrar(
kernel_def, name,
std::make_unique<tensorflow::KernelBuilderFactory>(builder));
TF_SetStatus(status, TF_OK, "");
}
// Returns the device stream handle for the kernel context `ctx`.
// Only pluggable devices are supported; in every other case nullptr is
// returned and `status` describes why.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Accessing device stream is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  auto* device_ctx =
      reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)
          ->op_device_context();
  if (device_ctx == nullptr) {  // CPU Device
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is not supported for a CPU device.");
    return nullptr;
  }
  if (!device_ctx->IsPluggableDevice()) {
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is only supported for pluggable devices.");
    return nullptr;
  }
  // Is a PluggableDevice
  TF_SetStatus(status, TF_OK, "");
  auto* c_stream = static_cast<stream_executor::CStream*>(
      device_ctx->stream()->implementation());
  return c_stream->Handle();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
// Returns the number of inputs of the kernel context `ctx`.
int TF_NumInputs(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->num_inputs();
}
// Returns the number of outputs of the kernel context `ctx`.
int TF_NumOutputs(TF_OpKernelContext* ctx) {
  return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->num_outputs();
}
// Wraps input `i` of `ctx` in a TF_Tensor and stores it in `*tensor`.
// `*tensor` is written only on success. An index outside
// [0, num_inputs) yields TF_OUT_OF_RANGE; conversion errors are reported
// through `status`.
void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
                 TF_Status* status) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  const bool in_range = (i >= 0) && (i < kernel_ctx->num_inputs());
  if (!in_range) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
    return;
  }
  TF_Tensor* wrapped = ::tensorflow::TF_TensorFromTensor(kernel_ctx->input(i),
                                                         &status->status);
  if (TF_GetCode(status) != TF_OK) return;
  *tensor = wrapped;
}
// Resolves the index range of the input argument called `name`.
// Results are written to `args->start` / `args->stop` (both -1 unless the
// lookup overwrites them) and the lookup status to `args->status`.
void TF_InputRange(TF_OpKernelContext* ctx, const char* name,
                   TF_InputRange_Args* args) {
  int range_begin = -1;
  int range_end = -1;
  const auto& kernel =
      reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->op_kernel();
  auto lookup_status = kernel.InputRange(name, &range_begin, &range_end);
  args->start = range_begin;
  args->stop = range_end;
  tensorflow::Set_TF_Status_from_Status(args->status, lookup_status);
}
// Returns the data type of input `index`. An out-of-range `index` fails a
// CHECK.
TF_DataType TF_InputDatatype(TF_OpKernelContext* ctx, int index) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  CHECK_GE(index, 0);                         // Crash OK
  CHECK_LT(index, kernel_ctx->num_inputs());  // Crash OK
  const auto dtype = kernel_ctx->input_dtype(index);
  return static_cast<TF_DataType>(dtype);
}
// Sets output `i` of `ctx` to the tensor wrapped by `tensor`.
// An index outside [0, num_outputs) yields TF_OUT_OF_RANGE; a failed
// TF_Tensor -> Tensor conversion is reported through `status` and leaves the
// output untouched.
void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
                  TF_Status* status) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  const bool in_range = (i >= 0) && (i < kernel_ctx->num_outputs());
  if (!in_range) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
    return;
  }
  ::tensorflow::Tensor converted;
  const ::tensorflow::Status convert_status =
      ::tensorflow::TF_TensorToTensor(tensor, &converted);
  TF_SetStatus(status, TF_OK, "");
  ::tensorflow::Set_TF_Status_from_Status(status, convert_status);
  if (!convert_status.ok()) return;
  kernel_ctx->set_output(i, converted);
}
// Wraps output `i` of `ctx` in a TF_Tensor and returns it.
// Returns nullptr with `status` set when:
//   * `i` is outside [0, num_outputs)          -> TF_OUT_OF_RANGE
//   * the context holds no tensor for output i -> TF_FAILED_PRECONDITION
//   * wrapping the tensor fails                -> status from the conversion
TF_Tensor* TF_GetMutableOutput(TF_OpKernelContext* ctx, int i,
                               TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_outputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
    return nullptr;
  }
  // mutable_output() hands back a pointer; guard against it being null
  // instead of dereferencing it unconditionally.
  const ::tensorflow::Tensor* cc_tensor = cc_ctx->mutable_output(i);
  if (cc_tensor == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "output tensor has not been set");
    return nullptr;
  }
  TF_Tensor* result =
      ::tensorflow::TF_TensorFromTensor(*cc_tensor, &status->status);
  return TF_GetCode(status) == TF_OK ? result : nullptr;
}
// Serializes the function library definition reachable from `ctx` into
// `serialized_function_def_library`; the serialization status is stored in
// `status`.
void TF_GetSerializedFunctionDefLibrary(
    TF_OpKernelContext* ctx, TF_Buffer* serialized_function_def_library,
    TF_Status* status) {
  auto* runtime =
      reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->function_library();
  const auto* library_def = runtime->GetFunctionLibraryDefinition();
  auto library_proto = library_def->ToProto();
  tensorflow::Set_TF_Status_from_Status(
      status, tensorflow::MessageToBuffer(library_proto,
                                          serialized_function_def_library));
}
// Serializes the ConfigProto of the function library runtime behind `ctx`
// into `serialized_config_proto`. When the runtime exposes no ConfigProto a
// default-constructed one is serialized instead. The serialization status is
// stored in `status`.
void TF_GetSerializedConfigProto(TF_OpKernelContext* ctx,
                                 TF_Buffer* serialized_config_proto,
                                 TF_Status* status) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  tensorflow::ConfigProto config_copy;
  if (const tensorflow::ConfigProto* source =
          kernel_ctx->function_library()->config_proto()) {
    config_copy = *source;
  }
  tensorflow::Set_TF_Status_from_Status(
      status,
      tensorflow::MessageToBuffer(config_copy, serialized_config_proto));
}
// Serializes the resource handle carried by input `i` of `ctx` into
// `serialized_resource_handle_proto`; the serialization status is stored in
// `status`.
void TF_GetSerializedResourceHandleProto(
    TF_OpKernelContext* ctx, int i, TF_Buffer* serialized_resource_handle_proto,
    TF_Status* status) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  const tensorflow::ResourceHandle& resource = HandleFromInput(kernel_ctx, i);

  tensorflow::ResourceHandleProto resource_proto;
  resource.AsProto(&resource_proto);

  tensorflow::Set_TF_Status_from_Status(
      status, tensorflow::MessageToBuffer(resource_proto,
                                          serialized_resource_handle_proto));
}
// Reports the error held in `status` as a failure on the construction context.
void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
                                     TF_Status* status) {
  ::tensorflow::Status failure(::tensorflow::StatusFromTF_Status(status));
  reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->CtxFailure(
      failure);
}
// Reports the error held in `status` as a failure on the kernel context.
void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
  ::tensorflow::Status failure(::tensorflow::StatusFromTF_Status(status));
  reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->CtxFailure(failure);
}
// Reports size information for attribute `attr_name`.
//
//   *list_size   -1 for non-list attributes; otherwise the number of elements
//                in the first non-empty list field (0 for an empty list).
//   *total_size  string: its length; shape: its rank, or -1 when the rank is
//                unknown; list(string): summed string sizes; list(shape):
//                summed ranks (unknown-rank shapes add 0); -1 in every other
//                case.
//
// `status` is OK on success and InvalidArgument when the attribute is missing
// or has no value set; both outputs are -1 in those error cases.
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
                                         const char* attr_name,
                                         int32_t* list_size,
                                         int32_t* total_size,
                                         TF_Status* status) {
  // Default both outputs to "not applicable"; the cases below overwrite them.
  *list_size = -1;
  *total_size = -1;
  // Reset first so an error left in `status` by an earlier call is not read
  // as "attribute not found" (GetAttrValue only writes on failure).
  TF_SetStatus(status, TF_OK, "");
  const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
  if (attr == nullptr) return;

  switch (attr->value_case()) {
    case tensorflow::AttrValue::kS:
      *total_size = attr->s().length();
      break;
    case tensorflow::AttrValue::kShape:
      if (!attr->shape().unknown_rank()) {
        *total_size = attr->shape().dim_size();
      }
      break;
    case tensorflow::AttrValue::kI:
    case tensorflow::AttrValue::kF:
    case tensorflow::AttrValue::kB:
    case tensorflow::AttrValue::kType:
    case tensorflow::AttrValue::kTensor:
    case tensorflow::AttrValue::kPlaceholder:
    case tensorflow::AttrValue::kFunc:
      // No size information: both outputs stay -1.
      break;
    case tensorflow::AttrValue::kList: {
      const auto& list = attr->list();
      *list_size = 0;
      // The first non-empty field determines the reported size.
      if (list.s_size() > 0) {
        *list_size = list.s_size();
        *total_size = 0;
        for (int i = 0; i < list.s_size(); ++i) {
          *total_size += list.s(i).size();
        }
      } else if (list.i_size() > 0) {
        *list_size = list.i_size();
      } else if (list.f_size() > 0) {
        *list_size = list.f_size();
      } else if (list.b_size() > 0) {
        *list_size = list.b_size();
      } else if (list.type_size() > 0) {
        *list_size = list.type_size();
      } else if (list.shape_size() > 0) {
        *list_size = list.shape_size();
        *total_size = 0;
        for (int i = 0; i < list.shape_size(); ++i) {
          const auto& shape = list.shape(i);
          *total_size += shape.unknown_rank() ? 0 : shape.dim_size();
        }
      } else if (list.tensor_size() > 0) {
        *list_size = list.tensor_size();
      } else if (list.func_size() > 0) {
        // The original tested tensor_size() a second time here, so
        // list(func) attributes always reported a size of 0.
        *list_size = list.func_size();
      }
      break;
    }
    case tensorflow::AttrValue::VALUE_NOT_SET:
      status->status =
          InvalidArgument("Attribute '", attr_name, "' has no value set");
      break;
  }
}
// Generates a scalar getter and a list getter for one attribute kind.
//   func        suffix of the generated function names
//   c_type      element type exposed through the C API
//   cc_type     C++ type the scalar getter reads via GetAttr
//   attr_type   attr type name used to validate "list(<attr_type>)"
//   list_field  AttrValue list field accessed by the list getter
// The scalar getter writes `*val` only on success. The list getter copies at
// most `max_vals` elements into `vals`.
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
  void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
                                             const char* attr_name, \
                                             c_type* val, TF_Status* status) { \
    TF_SetStatus(status, TF_OK, ""); \
    cc_type attr_value; \
    ::tensorflow::Status cc_status = \
        reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->GetAttr( \
            attr_name, &attr_value); \
    ::tensorflow::Set_TF_Status_from_Status(status, cc_status); \
    if (!cc_status.ok()) return; \
    *val = static_cast<c_type>(attr_value); \
  } \
  void TF_OpKernelConstruction_GetAttr##func##List( \
      TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
      int max_vals, TF_Status* status) { \
    TF_SetStatus(status, TF_OK, ""); \
    const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
    if (!status->status.ok()) return; \
    const bool is_list = attr->value_case() == tensorflow::AttrValue::kList; \
    if (!is_list) { \
      status->status = \
          InvalidArgument("Value for '", attr_name, "' is not a list."); \
      return; \
    } \
    status->status = \
        tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
    if (!status->status.ok()) return; \
    const auto& entries = attr->list(); \
    const auto count = std::min(max_vals, entries.list_field##_size()); \
    for (int idx = 0; idx < count; ++idx) { \
      vals[idx] = static_cast<c_type>(entries.list_field(idx)); \
    } \
  }
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, int32_t, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, int64_t, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
// Copies up to `max_length` bytes of string attribute `attr_name` into
// `value`. No terminating NUL is appended. Nothing is copied when the lookup
// fails (see `status`) or `max_length` is 0.
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
                                           const char* attr_name, char* value,
                                           size_t max_length,
                                           TF_Status* status) {
  std::string attr_value;
  ::tensorflow::Status cc_status =
      reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->GetAttr(
          attr_name, &attr_value);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
  if (!status->status.ok() || max_length == 0) return;

  const size_t bytes_to_copy =
      std::min<size_t>(attr_value.length(), max_length);
  std::memcpy(value, attr_value.data(), bytes_to_copy);
}
// Copies up to `max_values` strings of list(string) attribute `attr_name`
// back to back into `storage`. `values[i]` points at the start of string i
// inside `storage` and `lengths[i]` holds its size. InvalidArgument is
// reported as soon as a string does not fit in the remaining storage; entries
// already written stay in place.
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
                                               const char* attr_name,
                                               char** values, size_t* lengths,
                                               int max_values, void* storage,
                                               size_t storage_size,
                                               TF_Status* status) {
  std::vector<std::string> strings;
  ::tensorflow::Status cc_status =
      reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->GetAttr(
          attr_name, &strings);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
  if (!status->status.ok()) return;

  const int count = std::min(max_values, static_cast<int>(strings.size()));
  char* cursor = static_cast<char*>(storage);
  for (int idx = 0; idx < count; ++idx) {
    const std::string& current = strings[idx];
    const size_t current_size = current.size();
    values[idx] = cursor;
    lengths[idx] = current_size;
    char* const storage_end = static_cast<char*>(storage) + storage_size;
    if ((cursor + current_size) > storage_end) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of strings");
      return;
    }
    memcpy(cursor, current.data(), current_size);
    cursor += current_size;
  }
}
// Reads tensor attribute `attr_name` and returns it wrapped as a TF_Tensor in
// `*val`. `*val` is nullptr when the attribute lookup fails.
void TF_OpKernelConstruction_GetAttrTensor(TF_OpKernelConstruction* ctx,
                                           const char* attr_name,
                                           TF_Tensor** val, TF_Status* status) {
  *val = nullptr;
  ::tensorflow::Tensor attr_tensor;
  ::tensorflow::Status cc_status =
      reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->GetAttr(
          attr_name, &attr_tensor);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
  if (status->status.ok()) {
    *val = TF_TensorFromTensor(attr_tensor, &status->status);
  }
}
// Reads list(tensor) attribute `attr_name` and stores up to `max_values`
// wrapped tensors in `vals`. Stops at the first conversion failure, leaving
// the error in `status`.
void TF_OpKernelConstruction_GetAttrTensorList(TF_OpKernelConstruction* ctx,
                                               const char* attr_name,
                                               TF_Tensor** vals, int max_values,
                                               TF_Status* status) {
  std::vector<::tensorflow::Tensor> tensors;
  ::tensorflow::Status cc_status =
      reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->GetAttr(
          attr_name, &tensors);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
  if (!status->status.ok()) return;

  const int count = std::min(max_values, static_cast<int>(tensors.size()));
  int idx = 0;
  while (idx < count) {
    vals[idx] = TF_TensorFromTensor(tensors[idx], &status->status);
    if (!status->status.ok()) return;
    ++idx;
  }
}
// Reads function attribute `attr_name` and returns it as a serialized
// NameAttrList in a newly allocated TF_Buffer.
// Returns nullptr, with the error in `status`, when the attribute lookup or
// the serialization fails.
TF_Buffer* TF_OpKernelConstruction_GetAttrFunction(TF_OpKernelConstruction* ctx,
                                                   const char* attr_name,
                                                   TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  tensorflow::NameAttrList function;
  auto cc_status = cc_ctx->GetAttr(attr_name, &function);
  if (!cc_status.ok()) {
    tsl::Set_TF_Status_from_Status(status, cc_status);
    return nullptr;
  }
  TF_Buffer* buffer = TF_NewBuffer();
  cc_status = tensorflow::MessageToBuffer(function, buffer);
  tsl::Set_TF_Status_from_Status(status, cc_status);
  if (!cc_status.ok()) {
    // Free the buffer on failure; the original returned nullptr here and
    // leaked it.
    TF_DeleteBuffer(buffer);
    return nullptr;
  }
  return buffer;
}
// Returns whether the construction context has an attribute called
// `attr_name`. `status` is not read or written.
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
                                     const char* attr_name, TF_Status* status) {
  return reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx)->HasAttr(
      attr_name);
}
// Returns a view of the NodeDef name behind `ctx`. The view refers to the
// NodeDef's string and does not own the bytes.
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
  const auto& node_name =
      reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx)->def().name();
  TF_StringView view;
  view.data = node_name.data();
  view.len = node_name.length();
  return view;
}
// Returns the expected data type of output `i`. An out-of-range `i` fails a
// CHECK.
TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  CHECK_GE(i, 0);
  CHECK_LT(i, kernel_ctx->num_outputs());
  const auto dtype = kernel_ctx->expected_output_dtype(i);
  return static_cast<TF_DataType>(dtype);
}
// Returns whether input `i` has the HOST_MEMORY memory type.
// An index outside [0, num_inputs) yields false with TF_OUT_OF_RANGE in
// `status`; otherwise `status` is OK.
bool TF_IsHostMemoryInput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  const bool in_range = (i >= 0) && (i < kernel_ctx->num_inputs());
  if (!in_range) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
    return false;
  }
  TF_SetStatus(status, TF_OK, "");
  return kernel_ctx->input_memory_type(i) == tensorflow::HOST_MEMORY;
}
// Returns whether output `i` has the HOST_MEMORY memory type.
// An index outside [0, num_outputs) yields false with TF_OUT_OF_RANGE in
// `status`; otherwise `status` is OK.
bool TF_IsHostMemoryOutput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  const bool in_range = (i >= 0) && (i < kernel_ctx->num_outputs());
  if (!in_range) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
    return false;
  }
  TF_SetStatus(status, TF_OK, "");
  return kernel_ctx->output_memory_type(i) == tensorflow::HOST_MEMORY;
}
// Returns the step id of the kernel context `ctx`.
int64_t TF_StepId(TF_OpKernelContext* ctx) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return kernel_ctx->step_id();
}
// Serializes the NodeDef behind `ctx` into a newly allocated TF_Buffer.
// On failure the buffer is deleted, nullptr is returned and `status` holds
// the error.
TF_Buffer* TF_OpKernelConstruction_GetNodeDef(TF_OpKernelConstruction* ctx,
                                              TF_Status* status) {
  auto* construction =
      reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  TF_Buffer* serialized = TF_NewBuffer();
  status->status = MessageToBuffer(construction->def(), serialized);
  if (status->status.ok()) return serialized;

  TF_DeleteBuffer(serialized);
  return nullptr;
}
// Returns the `frame_id` field of the context's frame/iteration record.
uint64_t TF_GetFrameId(TF_OpKernelContext* ctx) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return kernel_ctx->frame_iter().frame_id;
}
// Returns the GraphDef version reported by the context's function library
// runtime.
int TF_GetGraphDefVersion(TF_OpKernelContext* ctx) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  auto* runtime = kernel_ctx->function_library();
  return runtime->graph_def_version();
}
// Returns the `iter_id` field of the context's frame/iteration record.
int64_t TF_GetIterId(TF_OpKernelContext* ctx) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return kernel_ctx->frame_iter().iter_id;
}
// Returns the step id of the kernel context `ctx`.
int64_t TF_GetStepId(TF_OpKernelContext* ctx) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return kernel_ctx->step_id();
}
// Returns the id of the device the kernel runs on, or -1 when no id can be
// determined.
int TF_GetDeviceId(TF_OpKernelContext* ctx) {
  // TensorFlow always sets device in OpKernelContext.
  const tensorflow::DeviceBase* device_base =
      reinterpret_cast<tensorflow::OpKernelContext*>(ctx)->device();
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  // Mobile/slim builds read the id straight from the parsed device name.
  return device_base->parsed_name().has_id ? device_base->parsed_name().id
                                           : -1;
#else
  const auto* underlying = reinterpret_cast<const tensorflow::Device*>(
      device_base->UnderlyingDevice());
  const tsl::StatusOr<int> device_id = tsl::GetDeviceIdFromDeviceParsedName(
      underlying->parsed_name(),
      tensorflow::DeviceType(underlying->device_type()));
  return device_id.ok() ? *device_id : -1;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
// Returns a non-owning view of the name of the kernel behind `ctx`.
TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx) {
  const auto& kernel =
      reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->op_kernel();
  TF_StringView view;
  view.data = kernel.name().data();
  view.len = kernel.name().length();
  return view;
}
// Returns a non-owning view of the default container name of the resource
// manager attached to `ctx`.
TF_StringView TF_GetResourceMgrDefaultContainerName(TF_OpKernelContext* ctx) {
  auto* resource_mgr =
      reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->resource_manager();
  TF_StringView view;
  view.data = resource_mgr->default_container().data();
  view.len = resource_mgr->default_container().length();
  return view;
}
// Returns a non-owning view of the requested-input string at position `index`
// of the kernel behind `ctx`. `index` is not range-checked here.
TF_StringView TF_GetOpKernelRequestedInput(TF_OpKernelContext* ctx,
                                           size_t index) {
  const auto& kernel =
      reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->op_kernel();
  TF_StringView view;
  view.data = kernel.requested_input(index).data();
  view.len = kernel.requested_input(index).length();
  return view;
}
// Allocates output `index` with the shape given by `dims[0..num_dims)` and
// returns it wrapped as a TF_Tensor. Returns nullptr with the error in
// `status` when allocation or wrapping fails. `dtype` and `len` are not used
// by this implementation.
TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
                             TF_DataType dtype, const int64_t* dims,
                             int num_dims, size_t len, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
  tensorflow::gtl::ArraySlice<const int64_t> shape_dims(dims, num_dims);

  tensorflow::Tensor* allocated = nullptr;
  tensorflow::Status cc_status = kernel_ctx->allocate_output(
      index, tensorflow::TensorShape(shape_dims), &allocated);

  TF_Tensor* wrapped = nullptr;
  if (cc_status.ok()) {
    wrapped = TF_TensorFromTensor(*allocated, &cc_status);
  }
  if (!cc_status.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
    return nullptr;
  }
  return wrapped;
}
// Calls forward_input_or_allocate_output on the kernel context for output
// `output_index`, with the given candidate input indices and output shape,
// and returns the resulting tensor wrapped as a TF_Tensor. `forwarded_input`
// is passed through to that call unchanged. Returns nullptr with the error in
// `status` on failure.
TF_Tensor* TF_ForwardInputOrAllocateOutput(
    TF_OpKernelContext* context, const int* candidate_input_indices,
    int num_candidate_input_indices, int output_index,
    const int64_t* output_dims, int output_num_dims, int* forwarded_input,
    TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);

  tensorflow::gtl::ArraySlice<int> candidates(candidate_input_indices,
                                              num_candidate_input_indices);
  tensorflow::gtl::ArraySlice<const int64_t> shape_dims(output_dims,
                                                        output_num_dims);

  tensorflow::Tensor* output = nullptr;
  tensorflow::Status cc_status = kernel_ctx->forward_input_or_allocate_output(
      candidates, output_index, tensorflow::TensorShape(shape_dims), &output,
      forwarded_input);

  TF_Tensor* wrapped = nullptr;
  if (cc_status.ok()) {
    wrapped = TF_TensorFromTensor(*output, &cc_status);
  }
  if (!cc_status.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
    return nullptr;
  }
  return wrapped;
}
// Allocates a temporary tensor of type `dtype` and shape
// `dims[0..num_dims)` and returns it wrapped as a TF_Tensor.
// `attributes` may be nullptr. When provided, its `struct_size` must be
// non-zero (TF_INVALID_ARGUMENT otherwise) and `on_host` requests host
// allocation. Returns nullptr with the error in `status` on failure.
TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
                           const int64_t* dims, int num_dims,
                           TF_AllocatorAttributes* attributes,
                           TF_Status* status) {
  auto* kernel_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
  TF_SetStatus(status, TF_OK, "");
  tensorflow::gtl::ArraySlice<const int64_t> shape_dims(dims, num_dims);

  const bool has_attributes = attributes != nullptr;
  if (has_attributes && !attributes->struct_size) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        "TF_AllocatorAttributes struct "
        "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
    return nullptr;
  }

  tensorflow::AllocatorAttributes allocator_attr;
  if (has_attributes && attributes->on_host) {
    allocator_attr.set_on_host(true);
  }

  tensorflow::Tensor temp;
  tensorflow::Status cc_status = kernel_ctx->allocate_temp(
      static_cast<tensorflow::DataType>(dtype),
      tensorflow::TensorShape(shape_dims), &temp, allocator_attr);

  TF_Tensor* wrapped = nullptr;
  if (cc_status.ok()) {
    wrapped = TF_TensorFromTensor(temp, &cc_status);
  }
  if (!cc_status.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
    return nullptr;
  }
  return wrapped;
}