blob: 259d1cac9df4d50a99dd6ef5ce6ca5d2f07c2961 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels_experimental.h"
#include <algorithm>
#include <optional>
#include <string>
#include <utility>
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/ref_var.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/kernels/data/optional_ops_util.h"
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/kernels/tensor_list_util.h"
#include "tensorflow/core/kernels/variant_ops_util.h"
#include "tensorflow/core/platform/abi.h"
#endif // IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
using tensorflow::AllocatorAttributes;
using tensorflow::mutex_lock;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TF_TensorFromTensor;
using tensorflow::Var;
using tensorflow::Variant;
using tensorflow::errors::InvalidArgument;
struct TF_VariableInputLockHolder {
TF_VariableInputLockHolder(
std::vector<tensorflow::Var*> vars,
std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks,
std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks)
: vars(std::move(vars)),
locks(std::move(locks)),
shared_locks(std::move(shared_locks)) {}
std::vector<tensorflow::Var*> vars;
std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks;
std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks;
};
tensorflow::Status EnsureSparseVariableAccess(
TF_OpKernelContext* ctx, bool variantType,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest),
tensorflow::Var* var, bool lock_held = false) {
auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
if (var->copy_on_read_mode.load()) {
return ::tensorflow::OkStatus();
}
std::optional<mutex_lock> ml;
if (!lock_held) {
ml.emplace(*var->mu());
}
// Once copy-on-read mode is True the refcount is guaranteed to be 1. This can
// also happen if there are no concurrent reads of the variable and
// copy-on-read mode is false.
if (var->tensor()->RefCountIsOne()) {
var->copy_on_read_mode.store(true);
return ::tensorflow::OkStatus();
}
Tensor tmp;
if (variantType) {
AllocatorAttributes attr;
attr.set_on_host(true);
TF_RETURN_IF_ERROR(context->allocate_temp(
var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr));
const auto elements_in = var->tensor()->flat<Variant>();
auto elements_out = tmp.flat<Variant>();
for (int64_t i = 0; i < elements_in.size(); ++i) {
elements_out(i) = elements_in(i);
}
} else {
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
TF_RETURN_IF_ERROR(context->allocate_temp(
var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr));
tensorflow::Status s;
TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
TF_Tensor* tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
copyFunc(ctx, tf_tensor, tf_tmp);
}
*var->tensor() = tmp;
var->copy_on_read_mode.store(true);
return ::tensorflow::OkStatus();
}
tensorflow::Status PrepareToUpdateVariable(
TF_OpKernelContext* ctx, tensorflow::Tensor* tensor, bool copy_on_read_mode,
bool variantType,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest)) {
auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
if (copy_on_read_mode || !tensor->RefCountIsOne()) {
// Tensor's buffer is in use by some read, so we need to copy before
// updating.
Tensor tmp;
if (variantType) {
AllocatorAttributes attr;
attr.set_on_host(true);
TF_RETURN_IF_ERROR(
context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
const auto elements_in = tensor->flat<Variant>();
auto elements_out = tmp.flat<Variant>();
for (int64_t i = 0; i < elements_in.size(); ++i) {
elements_out(i) = elements_in(i);
}
} else {
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
TF_RETURN_IF_ERROR(
context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
tensorflow::Status s;
TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
copyFunc(ctx, tf_tensor, tf_tmp);
}
*tensor = tmp;
}
return ::tensorflow::OkStatus();
}
tensorflow::mutex* GetTrainingVariableMutex(TF_OpKernelContext* ctx,
int32_t input,
tensorflow::Var** maybe_resource) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
*maybe_resource = nullptr;
if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) {
if (LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), maybe_resource)
.ok()) {
return (*maybe_resource)->mu();
} else {
cc_ctx->CtxFailureWithWarning(
tensorflow::errors::Internal("Invalid variable reference."));
return nullptr;
}
}
return cc_ctx->input_ref_mutex(input);
}
void TF_AssignVariable(TF_OpKernelContext* ctx, int input_index,
int value_index, bool validate_shape,
void (*copyFunc)(TF_OpKernelContext* ctx,
TF_Tensor* source, TF_Tensor* dest),
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
tensorflow::core::RefCountPtr<tensorflow::Var> variable;
const tensorflow::Tensor& value = cc_ctx->input(value_index);
OP_REQUIRES_OK(cc_ctx, tensorflow::LookupOrCreateResource<tensorflow::Var>(
cc_ctx, HandleFromInput(cc_ctx, input_index),
&variable, [&value](tensorflow::Var** ptr) {
*ptr = new tensorflow::Var(value.dtype());
*(*ptr)->tensor() = value;
(*ptr)->is_initialized = true;
return ::tensorflow::OkStatus();
}));
tensorflow::mutex_lock ml(*variable->mu());
if (validate_shape) {
OP_REQUIRES(cc_ctx,
(!variable->is_initialized ||
variable->tensor()->shape().IsSameSize(value.shape())),
InvalidArgument(
"Trying to assign to variable with tensor with wrong shape."
" Expected ",
variable->tensor()->shape().DebugString(), " got ",
value.shape().DebugString()));
}
if (variable->copy_on_read_mode.load()) {
tensorflow::Tensor tmp;
tensorflow::AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
OP_REQUIRES_OK(cc_ctx, cc_ctx->allocate_temp(value.dtype(), value.shape(),
&tmp, attr));
tensorflow::Status s;
TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
copyFunc(ctx, tf_value, tf_tmp);
*variable->tensor() = tmp;
} else {
*variable->tensor() = value;
}
variable->is_initialized = true;
TF_SetStatus(status, TF_OK, "");
}
void TF_AssignRefVariable(TF_OpKernelContext* ctx, int input_ref_index,
int output_ref_index, int value_index,
bool use_locking, bool validate_shape,
void (*copyFunc)(TF_OpKernelContext* ctx,
TF_Tensor* source, TF_Tensor* dest),
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
auto copy = [copyFunc, ctx](::tensorflow::OpKernelContext* cc_ctx,
::tensorflow::Tensor* lhs,
const ::tensorflow::Tensor& rhs) {
::tensorflow::Status s;
TF_Tensor* tf_lhs = TF_TensorFromTensor(*lhs, &s);
OP_REQUIRES_OK(cc_ctx, s);
TF_Tensor* tf_rhs = TF_TensorFromTensor(rhs, &s);
if (!s.ok()) {
TF_DeleteTensor(tf_lhs);
OP_REQUIRES_OK(cc_ctx, s);
}
copyFunc(ctx, tf_rhs, tf_lhs);
};
::tensorflow::AssignRefVariable(cc_ctx, input_ref_index, output_ref_index,
value_index, use_locking, validate_shape,
false, copy);
TF_SetStatus(status, TF_OK, "");
}
void TF_AssignUpdateVariable(TF_OpKernelContext* ctx, int input_index,
int value_index, int Op, int isVariantType,
void (*copyFunc)(TF_OpKernelContext* ctx,
TF_Tensor* source,
TF_Tensor* dest),
void (*updateFunc)(TF_OpKernelContext* ctx,
TF_Tensor* tensor,
TF_Tensor* value, int Op),
TF_Status* tf_status) {
auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
tensorflow::core::RefCountPtr<Var> variable;
Status status =
LookupResource(context, HandleFromInput(context, input_index), &variable);
if (!status.ok()) {
printf("Failed with error: %s\n", tsl::NullTerminatedMessage(status));
abort();
}
const Tensor& value = context->input(value_index);
mutex_lock ml(*variable->mu());
Tensor* var_tensor = variable->tensor();
OP_REQUIRES(
context, var_tensor->shape().IsSameSize(value.shape()),
InvalidArgument("Cannot update variable with shape ",
var_tensor->shape().DebugString(),
" using a Tensor with shape ",
value.shape().DebugString(), ", shapes must be equal."));
OP_REQUIRES_OK(context,
PrepareToUpdateVariable(ctx, var_tensor,
variable->copy_on_read_mode.load(),
isVariantType, copyFunc));
tensorflow::Status s;
TF_Tensor* tf_var_tensor = TF_TensorFromTensor(*var_tensor, &s);
TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
updateFunc(ctx, tf_var_tensor, tf_value, Op);
TF_SetStatus(tf_status, TF_OK, "");
}
void TF_MaybeLockVariableInputMutexesInOrder(
TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs,
size_t len,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest),
TF_VariableInputLockHolder** lockHolder, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
bool any_resource = false;
std::vector<int> input_ids(inputs, inputs + len);
for (auto i : input_ids) {
if (cc_ctx->input_dtype(i) == tensorflow::DT_RESOURCE) {
any_resource = true;
break;
}
}
if (!do_lock && !any_resource) {
*lockHolder = new TF_VariableInputLockHolder({}, {}, {});
TF_SetStatus(status, TF_OK, "");
return;
}
std::vector<tensorflow::Var*> vars;
std::vector<tensorflow::mutex*> mutexes;
std::vector<int32_t> acquire_order;
for (auto input : input_ids) {
tensorflow::Var* var;
tensorflow::mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
mutexes.push_back(mutex);
}
}
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
auto locks = absl::make_unique<std::vector<tensorflow::mutex_lock>>();
auto shared_locks =
absl::make_unique<std::vector<tensorflow::tf_shared_lock>>();
locks->reserve(acquire_order.size());
for (auto acquire : acquire_order) {
tensorflow::mutex* mu = mutexes[acquire];
if (mu != nullptr) {
if (do_lock) {
locks->emplace_back(*mu);
} else {
shared_locks->emplace_back(*mu);
}
}
}
*lockHolder = new TF_VariableInputLockHolder(vars, std::move(locks),
std::move(shared_locks));
if (sparse) {
// Enable sparse variables' access.
// NOTE: This can not be done before the variable input locks are held,
// because a race condition can happen between this and another thread that
// turns off some variable's `copy_on_read_mode` after this thread enables
// sparse access; when a later function sees `copy_on_read_mode` is off, it
// will try to lock the variable again for updating `copy_on_read_mode` and
// cause the deadlock, since the variable mutex is non-re-entrant.
for (auto* var : vars) {
TF_CHECK_OK(EnsureSparseVariableAccess(
ctx, /*variantType=*/false, copyFunc, var, /*lock_held=*/true));
}
}
TF_SetStatus(status, TF_OK, "");
}
void TF_GetInputTensorFromVariable(TF_OpKernelContext* ctx, int input,
bool lock_held, bool isVariantType,
bool sparse,
void (*copyFunc)(TF_OpKernelContext* ctx,
TF_Tensor* source,
TF_Tensor* dest),
TF_Tensor** out, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
auto status_setter = ::tensorflow::gtl::MakeCleanup([cc_ctx, status]() {
::tensorflow::Set_TF_Status_from_Status(status, cc_ctx->status());
});
tensorflow::Status s;
if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) {
tensorflow::core::RefCountPtr<tensorflow::Var> var;
OP_REQUIRES_OK(
cc_ctx, LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), &var));
if (sparse) {
OP_REQUIRES_OK(cc_ctx, EnsureSparseVariableAccess(ctx, isVariantType,
copyFunc, var.get()));
*out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s);
OP_REQUIRES_OK(cc_ctx, s);
return;
}
OP_REQUIRES_OK(cc_ctx, PrepareToUpdateVariable(
ctx, var->tensor(),
var->copy_on_read_mode.load(), false, copyFunc));
*out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s);
OP_REQUIRES_OK(cc_ctx, s);
return;
}
*out = ::tensorflow::TF_TensorFromTensor(
cc_ctx->mutable_input(input, lock_held), &s);
OP_REQUIRES_OK(cc_ctx, s);
}
void TF_OpKernelContext_ForwardRefInputToRefOutput(TF_OpKernelContext* ctx,
int32_t input_index,
int32_t output_index) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
if (cc_ctx->input_dtype(input_index) != tensorflow::DT_RESOURCE) {
cc_ctx->forward_ref_input_to_ref_output(input_index, output_index);
}
}
void TF_ReleaseVariableInputLockHolder(TF_VariableInputLockHolder* lockHolder) {
if (lockHolder != nullptr) {
lockHolder->locks.reset();
for (tensorflow::Var* var : lockHolder->vars) {
var->Unref();
}
delete lockHolder;
}
}
void TF_GetInputByName(TF_OpKernelContext* ctx, const char* inputName,
TF_Tensor** tensor, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
const ::tensorflow::Tensor* cc_tensor = nullptr;
tensorflow::Status s = cc_ctx->input(inputName, &cc_tensor);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return;
}
TF_Tensor* result =
::tensorflow::TF_TensorFromTensor(*cc_tensor, &status->status);
if (TF_GetCode(status) == TF_OK) {
*tensor = result;
}
}
void TF_OpKernelConstruction_GetAttrTensorShape(TF_OpKernelConstruction* ctx,
const char* attr_name,
int64_t* dims, size_t num_dims,
TF_Status* status) {
::tensorflow::TensorShape shape;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &shape);
::tensorflow::Set_TF_Status_from_Status(status, s);
size_t rank = static_cast<size_t>(shape.dims());
if (!status->status.ok()) return;
if (num_dims != rank) {
status->status = InvalidArgument("Expected rank is ", num_dims,
" but actual rank is ", rank);
return;
}
for (int i = 0; i < rank; ++i) {
dims[i] = static_cast<int64_t>(shape.dim_size(i));
}
}
bool TF_IsRefInput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
if (i < 0 || i >= cc_ctx->num_inputs()) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
return false;
}
TF_SetStatus(status, TF_OK, "");
return cc_ctx->input_is_ref(i);
}
#ifndef IS_MOBILE_PLATFORM
template <typename T>
static Status ValidateVariantType(const Variant& variant) {
if (variant.get<T>() == nullptr) {
const std::string type_index_name =
::tensorflow::port::MaybeAbiDemangle(variant.TypeId().name());
return ::tensorflow::errors::Internal(
"VariantBinaryOpFn: Could not access object 'a', type_index: ",
type_index_name);
}
return ::tensorflow::OkStatus();
}
static Status VariantBinaryAddFunc(
::tensorflow::OpKernelContext* cc_ctx, const Variant& a, const Variant& b,
Variant* out,
void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b,
TF_Tensor* out));
static Status CCBinaryAddFunc(
::tensorflow::OpKernelContext* cc_ctx, const Tensor& cc_a,
const Tensor& cc_b, Tensor* cc_out,
void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b,
TF_Tensor* out)) {
if (cc_a.dtype() == ::tensorflow::DT_INVALID) {
*cc_out = cc_b;
return ::tensorflow::OkStatus();
}
if (cc_b.dtype() == ::tensorflow::DT_INVALID) {
*cc_out = cc_a;
return ::tensorflow::OkStatus();
}
Status status;
TF_Tensor* a = TF_TensorFromTensor(cc_a, &status);
TF_RETURN_IF_ERROR(status);
TF_Tensor* b = TF_TensorFromTensor(cc_b, &status);
if (!status.ok()) {
TF_DeleteTensor(a);
return status;
}
::tensorflow::AllocatorAttributes attr;
if (cc_a.dtype() == ::tensorflow::DT_VARIANT) {
attr.set_on_host(true);
}
status = cc_ctx->allocate_temp(cc_a.dtype(), cc_a.shape(), cc_out, attr);
if (!status.ok()) {
TF_DeleteTensor(a);
TF_DeleteTensor(b);
return status;
}
TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status);
if (!status.ok()) {
TF_DeleteTensor(a);
TF_DeleteTensor(b);
return status;
}
auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
if (cc_a.dtype() == ::tensorflow::DT_VARIANT) {
return VariantBinaryAddFunc(
cc_ctx, cc_a.scalar<Variant>()(), cc_b.scalar<Variant>()(),
cc_out->scalar<Variant>().data(), binary_add_func);
} else {
binary_add_func(ctx, a, b, out);
return cc_ctx->status();
}
};
static Status VariantBinaryAddFunc(
::tensorflow::OpKernelContext* cc_ctx, const Variant& a, const Variant& b,
Variant* out,
void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b,
TF_Tensor* out)) {
auto cc_binary_add = [binary_add_func](::tensorflow::OpKernelContext* cc_ctx,
const Tensor& cc_a, const Tensor& cc_b,
Tensor* cc_out) {
return CCBinaryAddFunc(cc_ctx, cc_a, cc_b, cc_out, binary_add_func);
};
if (out == nullptr) {
return ::tensorflow::errors::Internal(
"The output variant hasn't been initialized");
}
if (a.TypeId() != b.TypeId()) {
return ::tensorflow::errors::Internal(
"BinaryOpVariants: Variants a and b have different "
"type ids. Type names: '",
a.TypeName(), "' vs. '", b.TypeName(), "'");
}
if (a.TypeId() == tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) {
TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(a));
*out = ::tensorflow::TensorList();
return ::tensorflow::TensorListBinaryAdd(
cc_ctx, *a.get<::tensorflow::TensorList>(),
*b.get<::tensorflow::TensorList>(),
out->get<::tensorflow::TensorList>(), cc_binary_add);
} else if (a.TypeId() == tensorflow::TypeIndex::Make<
::tensorflow::data::OptionalVariant>()) {
TF_RETURN_IF_ERROR(
ValidateVariantType<::tensorflow::data::OptionalVariant>(a));
*out = ::tensorflow::data::OptionalVariant();
return ::tensorflow::data::OptionalBinaryAdd(
cc_ctx, *a.get<::tensorflow::data::OptionalVariant>(),
*b.get<::tensorflow::data::OptionalVariant>(),
out->get<::tensorflow::data::OptionalVariant>(), cc_binary_add);
}
const std::string type_index_name =
::tensorflow::port::MaybeAbiDemangle(a.TypeId().name());
return ::tensorflow::errors::Internal(
"No unary variant binary_op function found for op ADD Variant "
"type_name: ",
type_index_name, " for device type: ", cc_ctx->device()->name());
}
void TF_AddNVariant(TF_OpKernelContext* ctx,
void (*binary_add_func)(TF_OpKernelContext* ctx,
TF_Tensor* a, TF_Tensor* b,
TF_Tensor* out),
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
auto binary_add_variant =
[binary_add_func](::tensorflow::OpKernelContext* cc_ctx, const Variant& a,
const Variant& b, Variant* out) {
return VariantBinaryAddFunc(cc_ctx, a, b, out, binary_add_func);
};
::tensorflow::AddNVariant(cc_ctx, binary_add_variant);
::tensorflow::Set_TF_Status_from_Status(status, cc_ctx->status());
}
static Status ZerosLikeVariant(::tensorflow::OpKernelContext* cc_ctx,
const Variant& input, Variant* out,
void (*zeros_like_func)(TF_OpKernelContext* ctx,
TF_Tensor* input,
TF_Tensor* out)) {
auto cc_zeros_like_func = [zeros_like_func](
::tensorflow::OpKernelContext* cc_ctx,
const Tensor& cc_input, Tensor* cc_out) {
AllocatorAttributes attr;
if (cc_input.dtype() == ::tensorflow::DT_VARIANT) {
attr.set_on_host(true);
}
TF_RETURN_IF_ERROR(cc_ctx->allocate_temp(cc_input.dtype(), cc_input.shape(),
cc_out, attr));
switch (cc_input.dtype()) {
case ::tensorflow::DT_INVALID: {
*cc_out = Tensor(::tensorflow::DT_INVALID);
break;
}
case ::tensorflow::DT_VARIANT: {
// If the wrapped tensor is also a variant, recursively call
// ZerosLikeVariant to unwrap it the same way
Variant* out_variant = cc_out->scalar<Variant>().data();
TF_RETURN_IF_ERROR(ZerosLikeVariant(cc_ctx,
cc_input.scalar<Variant>()(),
out_variant, zeros_like_func));
break;
}
default: {
Status status;
TF_Tensor* input = TF_TensorFromTensor(cc_input, &status);
TF_RETURN_IF_ERROR(status);
TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status);
if (!status.ok()) {
TF_DeleteTensor(input);
return status;
}
auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
zeros_like_func(ctx, input, out);
}
}
return cc_ctx->status();
};
if (out == nullptr) {
return ::tensorflow::errors::Internal(
"The output variant hasn't been initialized");
}
if (input.TypeId() ==
tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) {
TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(input));
*out = ::tensorflow::TensorList();
return ::tensorflow::TensorListZerosLike(
cc_ctx, *input.get<::tensorflow::TensorList>(),
out->get<::tensorflow::TensorList>(), cc_zeros_like_func);
} else if (input.TypeId() == tensorflow::TypeIndex::Make<
::tensorflow::data::OptionalVariant>()) {
TF_RETURN_IF_ERROR(
ValidateVariantType<::tensorflow::data::OptionalVariant>(input));
*out = ::tensorflow::data::OptionalVariant();
return ::tensorflow::data::OptionalZerosLike(
cc_ctx, *input.get<::tensorflow::data::OptionalVariant>(),
out->get<::tensorflow::data::OptionalVariant>(), cc_zeros_like_func);
}
const std::string type_index_name =
::tensorflow::port::MaybeAbiDemangle(input.TypeId().name());
return ::tensorflow::errors::Internal(
"No unary variant unary_op function found for op ZEROS_LIKE Variant "
"type_name: ",
type_index_name, " for device type: ", cc_ctx->device()->name());
}
void TF_ZerosLikeVariant(TF_OpKernelContext* ctx,
void (*zeros_like_func)(TF_OpKernelContext* ctx,
TF_Tensor* input,
TF_Tensor* out),
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
const Tensor& input = cc_ctx->input(0);
OP_REQUIRES(cc_ctx, input.dims() == 0,
InvalidArgument(
"ZerosLike non-scalar Tensor with dtype=DT_VARIANT is not "
"supported."));
const Variant& v = input.scalar<Variant>()();
// DT_VARIANT tensors must be allocated on CPU since they wrap C++
// objects which can not be efficiently represented in GPU memory.
int numa_node = cc_ctx->device()->NumaNode();
Tensor out(::tensorflow::cpu_allocator(numa_node), ::tensorflow::DT_VARIANT,
::tensorflow::TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
Status cc_status = ZerosLikeVariant(cc_ctx, v, out_v, zeros_like_func);
::tensorflow::Set_TF_Status_from_Status(status, cc_status);
OP_REQUIRES_OK(cc_ctx, cc_status);
cc_ctx->set_output(0, out);
}
#endif // IS_MOBILE_PLATFORM