blob: a36ea55e311f37d34a48d2ba2a2891639d29e391 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
#define TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/kernels.h"
// --------------------------------------------------------------------------
// Experimental kernel C API for TensorFlow.
//
// The API here is subject to changes in the future.
// --------------------------------------------------------------------------
#ifdef __cplusplus
extern "C" {
#endif
// Opaque handle representing the variable-input mutexes acquired by
// TF_MaybeLockVariableInputMutexesInOrder. Plugins never inspect it; they only
// pass it to TF_ReleaseVariableInputLockHolder once the update is done.
typedef struct TF_VariableInputLockHolder TF_VariableInputLockHolder;
// Exposes the higher-level Assign operation so that Pluggable vendors can
// implement variable assignment in their plugin for training.
//
// Parameters:
//   ctx            - kernel context of the running op.
//   input_index    - index of the op input holding the variable.
//   value_index    - index of the op input holding the value to assign.
//   validate_shape - when true, the value's shape is checked against the
//                    variable's shape if the variable is already initialized
//                    (see the error conditions below).
//   copyFunc       - vendor-provided callback that copies the `source` tensor
//                    into the `dest` tensor. The caller (the plugin supplying
//                    `copyFunc`) takes ownership of the `source` and `dest`
//                    tensors handed to the callback and is responsible for
//                    freeing them with TF_DeleteTensor.
//   status         - receives the outcome of the call.
//
// An error is reported through `status` when ALL of the following hold:
//   1. `validate_shape` is `true`,
//   2. the variable is already initialized, and
//   3. the shape of the value tensor doesn't match the shape of the variable
//      tensor.
TF_CAPI_EXPORT extern void TF_AssignVariable(
TF_OpKernelContext* ctx, int input_index, int value_index,
bool validate_shape,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest),
TF_Status* status);
// Exposes the higher-level Assign operation on *ref* variables so that
// Pluggable vendors can implement it in their plugin for training.
//
// Parameters:
//   ctx              - kernel context of the running op.
//   input_ref_index  - index of the ref input holding the variable.
//   output_ref_index - index of the ref output the variable is exposed on.
//   value_index      - index of the op input holding the value to assign.
//   use_locking      - presumably whether the variable's mutex is held for
//                      the duration of the assignment; TODO confirm.
//   validate_shape   - presumably the same shape check as in
//                      TF_AssignVariable; TODO confirm.
//   copyFunc         - vendor-provided callback that copies the `source`
//                      tensor into the `dest` tensor. The caller (the plugin
//                      supplying `copyFunc`) takes ownership of the `source`
//                      and `dest` tensors handed to the callback and is
//                      responsible for freeing them with TF_DeleteTensor.
//   status           - receives the outcome of the call.
TF_CAPI_EXPORT extern void TF_AssignRefVariable(
TF_OpKernelContext* ctx, int input_ref_index, int output_ref_index,
int value_index, bool use_locking, bool validate_shape,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest),
TF_Status* status);
// Exposes the higher-level AssignUpdate operation (an in-place arithmetic
// update of a variable) so that Pluggable vendors can implement it in their
// plugin for training.
//
// Parameters:
//   ctx           - kernel context of the running op.
//   input_index   - index of the op input holding the variable.
//   value_index   - index of the op input holding the update value.
//   Op            - integer selector for the arithmetic operation; the same
//                   name appears as `updateFunc`'s last parameter, so it is
//                   presumably forwarded unchanged. Its encoding is not
//                   defined in this header; TODO confirm.
//   isVariantType - flag (int, nonzero = true) presumably indicating the
//                   variable holds Variant-typed data; TODO confirm.
//   copyFunc      - vendor-provided callback that copies the `source` tensor
//                   into the `dest` tensor.
//   updateFunc    - vendor-provided callback that applies the arithmetic
//                   operation `Op` to `tensor` using `value`.
//   status        - receives the outcome of the call.
//
// The caller (the plugin supplying the callbacks) takes ownership of the
// `source`, `dest`, `tensor` and `value` tensors handed to the callbacks and is
// responsible for freeing them with TF_DeleteTensor.
TF_CAPI_EXPORT extern void TF_AssignUpdateVariable(
TF_OpKernelContext* ctx, int input_index, int value_index, int Op,
int isVariantType,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest),
void (*updateFunc)(TF_OpKernelContext* ctx, TF_Tensor* tensor,
TF_Tensor* value, int Op),
TF_Status* status);
// Helper that acquires the mutexes of the variable inputs in a consistent
// order, so that an optimizer op can update weights in a thread-safe way.
//
// Parameters:
//   ctx        - kernel context of the running op.
//   do_lock    - presumably whether the mutexes are actually acquired; when
//                false no locking may take place; TODO confirm.
//   sparse     - presumably selects the sparse-update locking mode; TODO
//                confirm.
//   inputs     - array of `len` op input indices whose variables are to be
//                locked.
//   len        - number of elements in `inputs`.
//   copyFunc   - vendor-provided callback that copies the `source` tensor into
//                the `dest` tensor. The caller (the plugin supplying
//                `copyFunc`) takes ownership of the `source` and `dest`
//                tensors handed to the callback and is responsible for
//                freeing them with TF_DeleteTensor.
//   lockHolder - out-param receiving an opaque lock handle. Pass it to
//                TF_ReleaseVariableInputLockHolder to release the locks once
//                the weight update is done.
//   status     - receives the outcome of the call.
TF_CAPI_EXPORT extern void TF_MaybeLockVariableInputMutexesInOrder(
TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs,
size_t len,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest),
TF_VariableInputLockHolder** lockHolder, TF_Status* status);
// Returns, via `out`, the tensor backing the variable passed at op input index
// `input`.
//
// Parameters:
//   ctx           - kernel context of the running op.
//   input         - op input index of the variable.
//   lock_held     - presumably tells the runtime the variable's mutex is
//                   already held by the caller (e.g. via
//                   TF_MaybeLockVariableInputMutexesInOrder); TODO confirm.
//   isVariantType - presumably whether the variable holds Variant-typed data;
//                   TODO confirm.
//   sparse        - presumably selects the sparse-update access mode; TODO
//                   confirm.
//   copyFunc      - vendor-provided callback that copies the `source` tensor
//                   into the `dest` tensor. The caller (the plugin supplying
//                   `copyFunc`) takes ownership of the `source` and `dest`
//                   tensors handed to the callback and is responsible for
//                   freeing them with TF_DeleteTensor.
//   out           - out-param receiving the variable's tensor. NOTE(review):
//                   ownership of `*out` is not stated in this header;
//                   presumably freed with TF_DeleteTensor; confirm.
//   status        - receives the outcome of the call.
TF_CAPI_EXPORT extern void TF_GetInputTensorFromVariable(
TF_OpKernelContext* ctx, int input, bool lock_held, bool isVariantType,
bool sparse,
void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
TF_Tensor* dest),
TF_Tensor** out, TF_Status* status);
// Forwards the reference held by the ref input at `input_index` to the ref
// output at `output_index` of the same kernel context.
//
// Note: unlike most functions in this header, it takes no TF_Status, so no
// error is reported back to the caller through this interface.
TF_CAPI_EXPORT extern void TF_OpKernelContext_ForwardRefInputToRefOutput(
TF_OpKernelContext* ctx, int32_t input_index, int32_t output_index);
// Releases the locks represented by the opaque handle that
// TF_MaybeLockVariableInputMutexesInOrder returned through its `lockHolder`
// out-param.
//
// NOTE(review): whether `lockHolder` may be NULL, and whether the handle is
// freed by this call, is not stated in this header; presumably the handle
// must not be used afterwards. Confirm.
TF_CAPI_EXPORT extern void TF_ReleaseVariableInputLockHolder(
TF_VariableInputLockHolder* lockHolder);
// Looks up an op input by its name and returns it as a TF_Tensor.
//
// Parameters:
//   ctx       - kernel context of the running op.
//   inputName - name of the op input to fetch.
//   tensor    - out-param receiving the input tensor. NOTE(review): ownership
//               is not stated here; presumably the caller frees it with
//               TF_DeleteTensor; confirm.
//   status    - receives the outcome of the call (e.g. an unknown name;
//               TODO confirm the exact error code).
TF_CAPI_EXPORT extern void TF_GetInputByName(TF_OpKernelContext* ctx,
const char* inputName,
TF_Tensor** tensor,
TF_Status* status);
// Interprets the named kernel construction attribute as a shape attribute and
// fills in `dims` with the size of each dimension. `dims` must point to an
// array of length at least `num_dims` (ideally set to `total_size` from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,
// &total_size)).
//
// Parameters:
//   ctx       - kernel construction context holding the attribute.
//   attr_name - name of the shape attribute to read (C string).
//   dims      - caller-provided output buffer receiving one size per
//               dimension.
//   num_dims  - number of elements `dims` can hold.
//   status    - receives the outcome of the call.
//
// NOTE(review): the behavior when `num_dims` is smaller than the attribute's
// rank is not specified in this header; presumably an error is reported via
// `status`. Confirm.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape(
    TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* dims,
    size_t num_dims, TF_Status* status);
// Returns whether the op input at index `i` of `ctx` is a reference (ref)
// input, as the name suggests; TODO confirm. Errors (e.g. an out-of-range
// `i`) are presumably reported via `status`, in which case the return value
// should not be relied on; confirm.
TF_CAPI_EXPORT extern bool TF_IsRefInput(TF_OpKernelContext* ctx, int i,
TF_Status* status);
#ifndef IS_MOBILE_PLATFORM
// Exposes the higher-level AddN operation for Variant data types so that
// Pluggable vendors can implement it in their plugin.
//
// Parameters:
//   ctx             - kernel context of the running op.
//   binary_add_func - vendor-provided callback that performs a binary Add
//                     (`out` = `a` + `b`) on tensors unwrapped from the
//                     Variant tensors. The caller (the plugin supplying the
//                     callback) takes ownership of the `a`, `b` and `out`
//                     tensors handed to it and is responsible for freeing
//                     them with TF_DeleteTensor.
//   status          - receives the outcome of the call.
TF_CAPI_EXPORT extern void TF_AddNVariant(
TF_OpKernelContext* ctx,
void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b,
TF_Tensor* out),
TF_Status* status);
// Exposes the higher-level ZerosLike operation for Variant data types so that
// Pluggable vendors can implement it in their plugin.
//
// Parameters:
//   ctx             - kernel context of the running op.
//   zeros_like_func - vendor-provided callback that performs ZerosLike
//                     (fills `out` with zeros shaped like `input`) on tensors
//                     unwrapped from the Variant tensors. The caller (the
//                     plugin supplying the callback) takes ownership of the
//                     `input` and `out` tensors handed to it and is
//                     responsible for freeing them with TF_DeleteTensor.
//   status          - receives the outcome of the call.
TF_CAPI_EXPORT extern void TF_ZerosLikeVariant(
TF_OpKernelContext* ctx,
void (*zeros_like_func)(TF_OpKernelContext* ctx, TF_Tensor* input,
TF_Tensor* out),
TF_Status* status);
// Opaque forward declaration of the coordination service agent handle
// (available only on non-mobile platforms). NOTE(review): no declaration in
// this header uses it; presumably consumed by APIs declared elsewhere. Confirm.
typedef struct TF_CoordinationServiceAgent TF_CoordinationServiceAgent;
#endif // IS_MOBILE_PLATFORM
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_