blob: 78a4c19ea034e88757e379634f767f611410b4d2 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "Python.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "pybind11/chrono.h" // from @pybind11
#include "pybind11/complex.h" // from @pybind11
#include "pybind11/functional.h" // from @pybind11
#include "pybind11/pybind11.h" // from @pybind11
#include "pybind11/pytypes.h" // from @pybind11
#include "pybind11/stl.h" // from @pybind11
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/dlpack.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/safe_ptr.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/get_compiler_ir.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(TFE_Executor);
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge3);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge4);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);
PYBIND11_MAKE_OPAQUE(TF_DeviceList);
PYBIND11_MAKE_OPAQUE(TF_Function);
PYBIND11_MAKE_OPAQUE(TF_Buffer);
// Eager helper functions migrated from pywrap_tfe.i.
namespace tensorflow {
// We cannot use Context as an opaque type. SWIG also had
// difficulty directly passing the pointer around. These
// typemaps are migrated over from pywrap_tfe.i. I tried
// using a custom type caster, but we get segfaults periodically.
// TODO(amitpatankar): Move input and output logic of Context into a
// pybind11 custom type caster.
// Extracts the TFE_Context* stored in an (unnamed) PyCapsule.
TFE_Context* InputTFE_Context(const py::handle& ctx) {
  void* raw_context = PyCapsule_GetPointer(ctx.ptr(), /*name=*/nullptr);
  return static_cast<TFE_Context*>(raw_context);
}
// Wraps `context` in an unnamed PyCapsule whose destructor is
// TFE_DeleteContextCapsule, and returns the new capsule object.
PyObject* OutputTFE_Context(TFE_Context* context) {
  PyObject* capsule =
      PyCapsule_New(context, /*name=*/nullptr, TFE_DeleteContextCapsule);
  return capsule;
}
// Builds a TF_Buffer from the contents of a Python `bytes` object.
//
// PyBytes_AsStringAndSize only exposes the object's internal storage (no
// copy); the bytes are then handed to TF_NewBufferFromString.
// Throws py::error_already_set when `input` is not a bytes object (the Python
// error set by PyBytes_AsStringAndSize is propagated).
TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
  char* data = nullptr;
  Py_ssize_t length = 0;
  const bool converted =
      PyBytes_AsStringAndSize(input, &data, &length) != -1;
  if (!converted) {
    throw py::error_already_set();
  }
  return TF_NewBufferFromString(static_cast<void*>(data),
                                static_cast<size_t>(length));
}
// These functions are typemaps from the Python side. I did not use
// a custom type caster since the logic is slightly harder to follow. Callers
// of this converter include `TFE_Py_ExecuteCancelable_wrapper`,
// `TFE_Py_PackEagerTensors_wrapper` and `TFE_GetCompilerIr`.
// Converts `input_tensors` (Python None or a list) into a vector of
// TFE_TensorHandle pointers obtained via EagerTensor_Handle.
//
// - None yields an empty vector.
// - An element that is exactly an EagerTensor contributes its handle.
// - An element for which swig::IsEagerTensorSlow is true contributes the
//   handle of its `_tf_should_use_wrapped_value` attribute, if present.
// Throws a Python TypeError when `input_tensors` is not a list, when an
// element is an unsupported EagerTensor subclass, when an element is a
// (graph) Tensor that is out of scope here, or for any other element type.
TFE_InputTensorHandles InputTFE_InputTensorHandles(
    const py::handle& input_tensors) {
  TFE_InputTensorHandles input_tensor_handles;
  if (input_tensors.ptr() != Py_None) {
    if (!PyList_Check(input_tensors.ptr())) {
      tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
    }
    Py_ssize_t len = PyList_Size(input_tensors.ptr());
    input_tensor_handles.resize(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      // Borrowed reference; null only if the index lookup fails.
      PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
      if (!elem) {
        tensorflow::ThrowTypeError("Input Tensor does not exist.");
      }
      if (EagerTensor_CheckExact(elem)) {
        // Fast path: the element is exactly an EagerTensor.
        (input_tensor_handles)[i] = EagerTensor_Handle(elem);
      } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
        // Use equivalent of object.__getattribute__ to get the underlying
        // tf wrapped EagerTensor (if there is one).
        tensorflow::Safe_PyObjectPtr tf_should_use_attr(
#if PY_MAJOR_VERSION < 3
            PyString_InternFromString("_tf_should_use_wrapped_value")
#else
            PyUnicode_InternFromString("_tf_should_use_wrapped_value")
#endif
        );
        tensorflow::Safe_PyObjectPtr value_attr(
            PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
        if (value_attr) {
          // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
          (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
        } else {
          // This is a subclass of EagerTensor that we don't support.
          // Clear the error left by the failed attribute lookup before
          // raising our own TypeError.
          PyErr_Clear();
          tensorflow::ThrowTypeError(
              tensorflow::strings::StrCat(
                  "Saw an object that is an instance of a strict subclass of "
                  "EagerTensor, which is not supported.  Item ",
                  i, " is type: ", elem->ob_type->tp_name)
                  .c_str());
        }
      } else if (tensorflow::swig::IsTensor(elem)) {
        // If it isn't an EagerTensor, but is still a Tensor, it must be a graph
        // tensor.
        tensorflow::Safe_PyObjectPtr py_tensor_repr(PyObject_Repr(elem));
        std::string tensor_repr =
            py_tensor_repr ? TFE_GetPythonString(py_tensor_repr.get())
                           : "<unknown>";
        // NOTE(review): py_op, py_defined_graph and c_op are not null-checked;
        // if the `op` lookup fails, nullptr is passed to the following
        // PyObject_GetAttrString calls. Confirm Tensor objects always have
        // `op.graph` and `op._c_op`.
        tensorflow::Safe_PyObjectPtr py_op(PyObject_GetAttrString(elem, "op"));
        tensorflow::Safe_PyObjectPtr py_defined_graph(
            PyObject_GetAttrString(py_op.get(), "graph"));
        tensorflow::Safe_PyObjectPtr py_defined_graph_str(
            PyObject_Str(py_defined_graph.get()));
        std::string defined_graph_str =
            py_defined_graph_str
                ? TFE_GetPythonString(py_defined_graph_str.get())
                : "<unknown>";
        tensorflow::Safe_PyObjectPtr c_op(
            PyObject_GetAttrString(py_op.get(), "_c_op"));
        auto& node = py::cast<TF_Operation*>(c_op.get())->node;
        // NOTE(review): node_name_str and frame_str are computed below but are
        // not included in the error message that is raised.
        auto node_name_str = node.name();
        std::string frame_str, traceback_str;
        if (auto stack_trace = node.GetStackTrace()) {
          auto frame = stack_trace->LastUserFrame();
          frame_str =
              absl::StrFormat("File \"%s\", line %d, in %s", frame.file_name,
                              frame.line_number, frame.function_name);
          // Re-indent every line of the definition-site stack trace.
          auto stack_trace_list =
              absl::StrSplit(stack_trace->ToString({true}), '\n');
          traceback_str = absl::StrJoin(
              stack_trace_list, "", [&](std::string* out, const auto line) {
                absl::StrAppend(out, " ", line, "\n");
              });
        } else {
          frame_str = "<unknown>";
          traceback_str = "<unknown>\n";
        }
        // Keep in sync with func_graph.py.
        // TODO(b/200991648): Unify those two paths.
        tensorflow::ThrowTypeError(
            tensorflow::strings::StrCat(
                tensor_repr,
                " is out of scope and cannot be used here. "
                "Use return values, explicit Python locals or TensorFlow "
                "collections to access it.\n"
                "Please see https://www.tensorflow.org/guide/"
                "function#all_outputs_of_a_tffunction_must_be_return_values "
                "for more information.\n\n",
                tensor_repr, " was defined here:\n", traceback_str,
                "\nThe tensor ", tensor_repr,
                " cannot be accessed from here, because it was "
                "defined in ",
                defined_graph_str, ", which is out of scope.")
                .c_str());
      } else {
        // Neither an EagerTensor nor a Tensor.
        tensorflow::ThrowTypeError(
            tensorflow::strings::StrCat(
                "provided list of inputs contains objects other "
                "than 'EagerTensor'. Item ",
                i, " is type: ", elem->ob_type->tp_name)
                .c_str());
      }
    }
  }
  return input_tensor_handles;
}
// These functions are typemaps from the Python side. I did not use
// a custom type caster since the logic is slightly harder to follow. This
// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
// This function actually takes a number rather than an output Tensor holder.
// Converts `num_outputs` (a Python int) into a TFE_OutputTensorHandles vector
// holding that many null handles, which op execution is expected to fill in.
//
// Raises TypeError if `num_outputs` is not an int. Raises ValueError if it
// does not fit in a C long or in an int32. Non-positive values produce an
// empty vector.
TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
    const py::handle& num_outputs) {
  TFE_OutputTensorHandles output_tensor_handles;
#if PY_MAJOR_VERSION < 3
  if (!PyInt_Check(num_outputs.ptr())) {
#else
  if (!PyLong_Check(num_outputs.ptr())) {
#endif
    PyErr_SetString(PyExc_TypeError,
                    "expected an integer value (size of the number of "
                    "outputs of the operation)");
    throw py::error_already_set();
  }
#if PY_MAJOR_VERSION < 3
  long sz = PyInt_AsLong(num_outputs.ptr());  // NOLINT
#else
  long sz = PyLong_AsLong(num_outputs.ptr());  // NOLINT
#endif
  // On overflow PyLong_AsLong returns -1 and sets an OverflowError. `sz` is
  // meaningless then, so report the original Python value instead.
  if (sz == -1 && PyErr_Occurred()) {
    // Clear the pending error before calling back into the Python C API.
    PyErr_Clear();
    tensorflow::Safe_PyObjectPtr py_value_str(PyObject_Str(num_outputs.ptr()));
    std::string value_str = py_value_str
                                ? TFE_GetPythonString(py_value_str.get())
                                : "<unknown>";
    PyErr_SetString(PyExc_ValueError,
                    tensorflow::strings::StrCat(
                        "Number of outputs is too big: ", value_str)
                        .c_str());
    throw py::error_already_set();
  }
  // We can't handle more than int32 sizes for number of outputs.
  if (static_cast<long>(static_cast<int32_t>(sz)) != sz) {  // NOLINT
    PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
                                          "Number of outputs is too big: ", sz)
                                          .c_str());
    throw py::error_already_set();
  }
  if (sz > 0) {
    // Reuse the already-validated value rather than re-parsing the Python int.
    output_tensor_handles.resize(sz, nullptr);
  }
  return output_tensor_handles;
}
// Returns the single local device of the context in `ctx` whose parsed name is
// compatible with `device_name` (e.g. "GPU:0").
//
// Throws ValueError if `device_name` cannot be parsed, if several local
// devices match, or if none match.
tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) {
  auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
      tensorflow::InputTFE_Context(ctx));
  tensorflow::DeviceNameUtils::ParsedName input_device_name;
  if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name,
                                                         &input_device_name)) {
    tensorflow::ThrowValueError(
        absl::StrFormat("Failed parsing device name: '%s'. Note a valid device "
                        "string should at least contain a device type and a "
                        "device index, like \"GPU:0\".",
                        device_name)
            .c_str());
  }
  const std::vector<tensorflow::Device*> devices =
      context->ListLocalTfDevices();
  tensorflow::Device* matched_device = nullptr;
  // Range-for avoids the signed/unsigned comparison of an int index against
  // devices.size().
  for (tensorflow::Device* device : devices) {
    if (!tensorflow::DeviceNameUtils::AreCompatibleDevNames(
            input_device_name, device->parsed_name())) {
      continue;
    }
    // The device string must be unambiguous.
    if (matched_device != nullptr) {
      tensorflow::ThrowValueError(
          absl::StrFormat("Multiple devices match the provided string "
                          "'%s': '%s' and '%s'.",
                          device_name, matched_device->name(), device->name())
              .c_str());
    }
    matched_device = device;
  }
  if (matched_device == nullptr) {
    tensorflow::ThrowValueError(
        absl::StrFormat("No matching devices found for '%s'", device_name)
            .c_str());
  }
  return matched_device;
}
// Packs multiple `EagerTensor`s of the same dtype and shape into one
// `EagerTensor`.
//
// `tensors` is converted with InputTFE_InputTensorHandles. A failing status
// from TFE_CreatePackedTensorHandle is raised as the registered Python
// exception.
py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
                                           const py::handle& tensors) {
  TFE_Context* const eager_ctx = tensorflow::InputTFE_Context(context);
  TFE_InputTensorHandles inputs = InputTFE_InputTensorHandles(tensors);
  int num_inputs = inputs.size();
  tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
  TFE_TensorHandle* packed = TFE_CreatePackedTensorHandle(
      eager_ctx, inputs.data(), &num_inputs, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  return tensorflow::PyoOrThrow(
      EagerTensorFromHandle(packed, /*is_packed=*/true));
}
// This function was created from fusing the typemap logic in platform/base.i.
//
// Executes `op_name` on `device_name` with the given inputs and attrs, and
// returns a Python list of EagerTensors wrapping the output handles. A failing
// execution status is raised as the registered Python exception; that check
// happens after the list is built (same order as before). The list is held
// in a Safe_PyObjectPtr so it is released, not leaked, when the raise throws.
py::object TFE_Py_ExecuteCancelable_wrapper(
    const py::handle& context, const char* device_name, const char* op_name,
    const py::handle& inputs, const py::handle& attrs,
    tensorflow::CancellationManager* cancellation_manager,
    const py::handle& num_outputs) {
  TFE_Context* ctx = tensorflow::InputTFE_Context(context);
  TFE_InputTensorHandles input_tensor_handles =
      InputTFE_InputTensorHandles(inputs);
  TFE_OutputTensorHandles output_tensor_handles =
      InputTFE_OutputTensorHandles(num_outputs);
  tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
  TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
                           attrs.ptr(), tensorflow::wrap(cancellation_manager),
                           &output_tensor_handles, status.get());
  const int output_len = output_tensor_handles.size();
  tensorflow::Safe_PyObjectPtr output_list(PyList_New(output_len));
  if (!output_list) {
    // PyList_New failed and set a Python error (e.g. MemoryError).
    throw py::error_already_set();
  }
  for (int i = 0; i < output_len; ++i) {
    PyObject* output = EagerTensorFromHandle(output_tensor_handles.at(i));
    // PyList_SetItem steals the reference to `output`.
    PyList_SetItem(output_list.get(), i, output);
  }
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  return tensorflow::PyoOrThrow(output_list.release());
}
// Returns a Python list of `bytes`, one per physical device name reported by
// DeviceFactory::ListAllPhysicalDevices. A failing status is raised as the
// registered Python exception. Allocation failures are propagated as the
// pending Python error instead of crashing in PyList_SetItem.
static py::object TF_ListPhysicalDevices() {
  std::vector<string> devices;
  tensorflow::Status s =
      tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
  MaybeRaiseRegisteredFromStatus(s);
  Safe_PyObjectPtr result(PyList_New(devices.size()));
  if (!result) {
    throw py::error_already_set();
  }
  for (size_t i = 0; i < devices.size(); ++i) {
    const string& dev = devices[i];
    PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
    if (dev_obj == nullptr) {
      // `result` is released by Safe_PyObjectPtr on unwind.
      throw py::error_already_set();
    }
    // PyList_SetItem steals the reference to `dev_obj`.
    PyList_SetItem(result.get(), static_cast<Py_ssize_t>(i), dev_obj);
  }
  return tensorflow::PyoOrThrow(result.release());
}
// Returns a Python list of `bytes`, one per pluggable physical device name
// reported by DeviceFactory::ListPluggablePhysicalDevices. A failing status is
// raised as the registered Python exception. Allocation failures are
// propagated as the pending Python error instead of crashing in
// PyList_SetItem.
static py::object TF_ListPluggablePhysicalDevices() {
  std::vector<string> devices;
  tensorflow::Status s =
      tensorflow::DeviceFactory::ListPluggablePhysicalDevices(&devices);
  MaybeRaiseRegisteredFromStatus(s);
  Safe_PyObjectPtr result(PyList_New(devices.size()));
  if (!result) {
    throw py::error_already_set();
  }
  for (size_t i = 0; i < devices.size(); ++i) {
    const string& dev = devices[i];
    PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
    if (dev_obj == nullptr) {
      // `result` is released by Safe_PyObjectPtr on unwind.
      throw py::error_already_set();
    }
    // PyList_SetItem steals the reference to `dev_obj`.
    PyList_SetItem(result.get(), static_cast<Py_ssize_t>(i), dev_obj);
  }
  return tensorflow::PyoOrThrow(result.release());
}
static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
std::unordered_map<string, string> device_details;
tensorflow::Status s =
tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
tensorflow::Set_TF_Status_from_Status(status.get(), s);
MaybeRaiseRegisteredFromTFStatus(status.get());
return device_details;
}
// Clears the cache instance returned by TFE_TensorHandleCache::Get() and
// returns Python None.
static py::object TFE_ClearScalarCache() {
  auto* scalar_cache = tensorflow::TFE_TensorHandleCache::Get();
  scalar_cache->Clear();
  py::object none_value = py::none();
  return none_value;
}
// Returns compiler IR for a given function.
//
// `stage` picks the IR flavor. Accepted values: "hlo", "hlo_no_metadata",
// "hlo_serialized", "optimized_hlo", "optimized_hlo_serialized",
// "optimized_hlo_proto_serialized", "optimized_hlo_dot".
//
// `flat_arg_inputs` must be a list:
// - If it is empty or its first element is a tf.Tensor, the arguments are
//   passed as concrete tensor handles.
// - If its first element is a tf.TensorSpec, only each spec's dtype and shape
//   are passed.
//
// `captured_inputs` is converted with InputTFE_InputTensorHandles, which also
// accepts None.
//
// Throws ValueError for an unknown stage, an unsupported first argument type,
// an unparsable or unmatched device name, or a GetCompilerIr failure. Throws
// TypeError when `flat_arg_inputs` is not a list.
static py::bytes TFE_GetCompilerIr(py::handle& ctx,
                                   const char* concrete_function_name,
                                   const char* stage, const char* device_name,
                                   py::handle& flat_arg_inputs,
                                   py::handle& captured_inputs) {
  EagerContext* context = ContextFromInterface(
      reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
  std::string s_stage(stage);
  IrExportStage selected_stage = [&] {
    if (s_stage == "hlo") {
      return IrExportStage::HLO;
    } else if (s_stage == "hlo_no_metadata") {
      return IrExportStage::HLO_NO_METADATA;
    } else if (s_stage == "hlo_serialized") {
      return IrExportStage::HLO_SERIALIZED;
    } else if (s_stage == "optimized_hlo") {
      return IrExportStage::OPTIMIZED_HLO;
    } else if (s_stage == "optimized_hlo_serialized") {
      return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
    } else if (s_stage == "optimized_hlo_proto_serialized") {
      return IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED;
    } else if (s_stage == "optimized_hlo_dot") {
      return IrExportStage::OPTIMIZED_HLO_DOT;
    } else {
      // Keep this list in sync with the branches above.
      ThrowValueError(
          absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
                          "'hlo', 'hlo_no_metadata', 'hlo_serialized', "
                          "'optimized_hlo', 'optimized_hlo_serialized', "
                          "'optimized_hlo_proto_serialized', "
                          "'optimized_hlo_dot'",
                          s_stage)
              .c_str());
    }
  }();

  // PyList_Size/PyList_GetItem below would otherwise return error sentinels
  // (-1 / nullptr) that are then dereferenced.
  if (!PyList_Check(flat_arg_inputs.ptr())) {
    ThrowTypeError("flat_arg_inputs must be a list");
  }

  CompilerArgSource compiler_arg_source = [&] {
    if (PyList_Size(flat_arg_inputs.ptr()) == 0) {
      return CompilerArgSource::CONCRETE_INPUT;
    }
    PyObject* elem = PyList_GetItem(flat_arg_inputs.ptr(), 0);
    if (swig::IsTensorSpec(elem)) {
      return CompilerArgSource::TENSOR_SPEC;
    } else if (swig::IsTensor(elem)) {
      return CompilerArgSource::CONCRETE_INPUT;
    } else {
      ThrowValueError(
          tensorflow::strings::StrCat(
              "Only accept tf.TensorSpec or tf.Tensor but got type ",
              elem->ob_type->tp_name)
              .c_str());
    }
  }();

  Py_ssize_t flat_arg_len = PyList_Size(flat_arg_inputs.ptr());
  // InputTFE_InputTensorHandles accepts None for captured inputs; only query
  // the length of an actual list so no stray Python error is left set and no
  // negative size reaches reserve().
  Py_ssize_t captured_input_len = PyList_Check(captured_inputs.ptr())
                                      ? PyList_Size(captured_inputs.ptr())
                                      : 0;
  std::vector<ArgShapeAndDType> flat_args;
  std::vector<const TensorHandle*> captured_input_handles;
  if (compiler_arg_source == CompilerArgSource::TENSOR_SPEC) {
    flat_args.resize(flat_arg_len);
    captured_input_handles.reserve(captured_input_len);
    for (Py_ssize_t i = 0; i < flat_arg_len; ++i) {
      PyObject* elem_ptr = PyList_GetItem(flat_arg_inputs.ptr(), i);
      py::object elem = py::reinterpret_borrow<py::object>(elem_ptr);
      py::object py_dtype = elem.attr("dtype");
      py::object py_shape = elem.attr("shape");
      int dtype = py::cast<int>(py_dtype.attr("_type_enum"));
      auto shape = py::cast<std::vector<int64_t>>(py_shape);
      flat_args[i].dtype = DataType(dtype);
      flat_args[i].shape = TensorShape(shape);
    }
  } else if (compiler_arg_source == CompilerArgSource::CONCRETE_INPUT) {
    // Concrete arguments are passed ahead of the captured inputs.
    captured_input_handles.reserve(flat_arg_len + captured_input_len);
    TFE_InputTensorHandles handles =
        InputTFE_InputTensorHandles(flat_arg_inputs);
    for (TFE_TensorHandle* tensor_handle : handles) {
      AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
      captured_input_handles.push_back(
          TensorHandleFromInterface(abstract_tensor_handle));
    }
  }

  TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(captured_inputs);
  for (TFE_TensorHandle* tensor_handle : handles) {
    AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
    captured_input_handles.push_back(
        TensorHandleFromInterface(abstract_tensor_handle));
  }

  DeviceNameUtils::ParsedName input_device_name;
  if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
    ThrowValueError(
        absl::StrFormat("Failed parsing device name: '%s'", device_name)
            .c_str());
  }

  // Use the first local device compatible with the requested name.
  std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
  auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
    return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
                                                  d->parsed_name());
  });
  if (selected_device == devices.end()) {
    ThrowValueError(
        absl::StrFormat("No matching device found for '%s'", device_name)
            .c_str());
  }

  StatusOr<std::string> hlo_str = GetCompilerIr(
      selected_stage, context->pflr(), concrete_function_name, *selected_device,
      context, flat_args, captured_input_handles, compiler_arg_source);
  if (!hlo_str.ok()) {
    ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
                                    hlo_str.status().message())
                        .c_str());
  }
  return py::bytes(*hlo_str);
}
} // namespace tensorflow
namespace {
// Wrapper around the EagerContextThreadLocalData struct (defined in
// pywrap_tfe.h), so it can be accessed from Python.
//
// For PyObject* fields, the get_*() methods return a new reference; and the
// set_*() methods create a new reference (i.e., they do not steal a reference).
//
// The constructor registers thread-local data for `py_eager_context` and the
// destructor destroys it. Copying is therefore disabled: a copy's destructor
// would destroy the same context's data a second time.
class EagerContextThreadLocalDataWrapper {
 public:
  explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
                                              py::handle is_eager,
                                              py::handle device_spec)
      : py_eager_context_(py_eager_context.ptr()) {
    tensorflow::MakeEagerContextThreadLocalData(
        py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
  }

  EagerContextThreadLocalDataWrapper(
      const EagerContextThreadLocalDataWrapper&) = delete;
  EagerContextThreadLocalDataWrapper& operator=(
      const EagerContextThreadLocalDataWrapper&) = delete;

  ~EagerContextThreadLocalDataWrapper() {
    tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
  }

  bool get_is_eager() const { return GetData()->is_eager; }
  void set_is_eager(bool v) { GetData()->is_eager = v; }

  bool get_invoking_op_callbacks() const {
    return GetData()->invoking_op_callbacks;
  }
  void set_invoking_op_callbacks(bool v) {
    GetData()->invoking_op_callbacks = v;
  }

  py::object get_device_name() const {
    return GetPyObject(&GetData()->device_name);
  }
  void set_device_name(py::handle v) {
    SetPyObject(v, &GetData()->device_name);
  }

  py::object get_scope_name() const {
    return GetPyObject(&GetData()->scope_name);
  }
  void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }

  py::object get_device_spec() const {
    return GetPyObject(&GetData()->device_spec);
  }
  void set_device_spec(py::handle v) {
    SetPyObject(v, &GetData()->device_spec);
  }

  py::object get_function_call_options() const {
    return GetPyObject(&GetData()->function_call_options);
  }
  void set_function_call_options(py::handle v) {
    SetPyObject(v, &GetData()->function_call_options);
  }

  // Returns py::object (an owned reference) like every other getter.
  py::object get_executor() const { return GetPyObject(&GetData()->executor); }
  void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }

  py::object get_op_callbacks() const {
    return GetPyObject(&GetData()->op_callbacks);
  }
  void set_op_callbacks(py::handle v) {
    SetPyObject(v, &GetData()->op_callbacks);
  }

 private:
  // Looks up the thread-local data for py_eager_context_. A null result is
  // turned into py::error_already_set. This presumes the lookup sets a
  // Python error on failure; TODO confirm in pywrap_tfe.
  tensorflow::EagerContextThreadLocalData* GetData() const {
    auto* result =
        tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
    if (!result) {
      throw py::error_already_set();
    }
    return result;
  }

  // reinterpret_borrow increments the refcount, so the caller owns the
  // returned object while `obj` keeps its own reference.
  py::object GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
    return pybind11::reinterpret_borrow<py::object>(obj->get());
  }

  // Stores a new reference to `value` in `ptr`, releasing the previous one.
  void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
    Py_INCREF(value.ptr());
    ptr->reset(value.ptr());
  }

  PyObject* py_eager_context_;  // not owned (borrowed reference).
};
} // namespace
// py::return_value_policy::reference is defined as specified by the
// pybind11 documents listed here.
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
// This means that C++ maintains ownership of the object. We
// are only assigning this to functions that return opaque types.
PYBIND11_MODULE(_pywrap_tfe, m) {
py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
"TFE_ContextOptions");
py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
m, "TFE_MonitoringCounter0");
py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
m, "TFE_MonitoringCounter1");
py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
m, "TFE_MonitoringCounter2");
py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
m, "TFE_MonitoringStringGauge0");
py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
m, "TFE_MonitoringStringGauge1");
py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
m, "TFE_MonitoringStringGauge2");
py::class_<TFE_MonitoringStringGauge3> TFE_MonitoringStringGauge3_class(
m, "TFE_MonitoringStringGauge3");
py::class_<TFE_MonitoringStringGauge4> TFE_MonitoringStringGauge4_class(
m, "TFE_MonitoringStringGauge4");
py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
m, "TFE_MonitoringIntGauge0");
py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
m, "TFE_MonitoringIntGauge1");
py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
m, "TFE_MonitoringIntGauge2");
py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
m, "TFE_MonitoringBoolGauge0");
py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
m, "TFE_MonitoringBoolGauge1");
py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
m, "TFE_MonitoringBoolGauge2");
py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
m, "TFE_MonitoringCounterCell");
py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
m, "TFE_MonitoringIntGaugeCell");
py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
m, "TFE_MonitoringStringGaugeCell");
py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
m, "TFE_MonitoringBoolGaugeCell");
py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
m, "TFE_MonitoringSamplerCell");
py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
m, "TFE_MonitoringBuckets");
py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
m, "TFE_MonitoringSampler0");
py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
m, "TFE_MonitoringSampler1");
py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
m, "TFE_MonitoringSampler2");
py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
m, "TFE_CancellationManager");
py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
py::class_<TF_Function> TF_Function_class(m, "TF_Function");
m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
});
m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
return tensorflow::PyoOrThrow(
TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
});
m.def("TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
tensorflow::Device* matched_device =
tensorflow::GetMatchedDevice(ctx, device_name);
tensorflow::AllocatorAttributes attrs;
tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
if (absl::optional<tensorflow::AllocatorStats> stats =
allocator->GetStats()) {
return std::map<std::string, int64_t>{{"current", stats->bytes_in_use},
{"peak", stats->peak_bytes_in_use}};
}
tensorflow::ThrowValueError(
absl::StrFormat("Allocator stats not available for device '%s'",
device_name)
.c_str());
});
m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
tensorflow::Device* matched_device =
tensorflow::GetMatchedDevice(ctx, device_name);
tensorflow::AllocatorAttributes attrs;
tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
if (!allocator->ClearStats()) {
tensorflow::ThrowValueError(
absl::StrFormat("Cannot reset memory stats for device '%s'",
device_name)
.c_str());
}
});
// XLA Eager Logic
m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
// MLIR Logic
m.def("TF_IsMlirBridgeEnabled", [] {
// Since python protobuf enums are integers, cast to an integer before
// returning the enum to python.
return static_cast<int32_t>(
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
});
m.def("TF_EnableMlirBridge", [](bool enabled) {
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
enabled
? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
: tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_DISABLED;
});
m.def("TF_EnableXlaDevices", [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
});
m.def("TF_ResetJitCompilerFlags",
[] { tensorflow::ResetJitCompilerFlags(); });
// TFE_Context Logic
m.def(
"TFE_NewContext",
[](const TFE_ContextOptions* opts) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_Context* context = TFE_NewContext(opts, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
},
py::return_value_policy::reference);
m.def("TFE_DeleteContext", [](py::handle& o) {
TFE_DeleteContext(tensorflow::InputTFE_Context(o));
});
m.def(
"TFE_ContextListDevices",
[](py::handle& o) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
},
py::return_value_policy::reference);
m.def(
"TFE_SetLogicalCpuDevices",
[](py::handle& ctx, int num_cpus, const char* prefix) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_SetLogicalCpuDevices(tensorflow::InputTFE_Context(ctx), num_cpus,
prefix, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
},
py::return_value_policy::reference);
m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
});
m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ContextAddFunctionDef",
[](py::handle& ctx, const char* serialized_function_def, size_t size) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
serialized_function_def, size,
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def(
"TFE_ContextGetFunction",
[](py::handle& ctx, const char* function_name) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TF_Function* tf_function = TFE_ContextGetFunction(
tensorflow::InputTFE_Context(ctx), function_name, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return tf_function;
},
py::return_value_policy::reference);
m.def("TFE_ContextGetFunctionDef",
[](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
function_name, &buf, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ContextGetGraphDebugInfo",
[](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextGetGraphDebugInfo(tensorflow::InputTFE_Context(ctx),
function_name, &buf, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Removes the named function from the context's function library; raises the
// registered Python exception on failure.
m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
                            status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Reports whether `name` is registered in the context. TFE_ContextHasFunction
// takes no TF_Status, so there is no error to propagate and no status object
// is allocated.
m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
  return TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
});
// Returns the names of the functions registered in the context.
m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
  return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
      ->ListFunctionNames();
});
// Turns run-metadata collection on for the context.
m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
  TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
});
// Turns run-metadata collection off for the context. This must call the
// Disable entry point; calling Enable here would make the Python-level
// "disable" a no-op (or re-enable collection).
m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
  TFE_ContextDisableRunMetadata(tensorflow::InputTFE_Context(ctx));
});
// Turns graph collection on for the context.
m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
  TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
});
// Turns graph collection off for the context.
m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
  TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
});
// Exports the collected run metadata into the caller-supplied `buf`; raises
// the registered Python exception on failure.
m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
  auto status = tensorflow::make_safe(TF_NewStatus());
  TFE_Context* context = tensorflow::InputTFE_Context(ctx);
  TFE_ContextExportRunMetadata(context, &buf, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Clears the context's caches.
m.def("TFE_ContextClearCaches", [](py::handle& ctx) {
  TFE_Context* context = tensorflow::InputTFE_Context(ctx);
  TFE_ContextClearCaches(context);
});
// Returns the context's id.
m.def("TFE_GetContextId", [](py::handle& ctx) {
  TFE_Context* context = tensorflow::InputTFE_Context(ctx);
  return TFE_GetContextId(context);
});
// Returns the context's device placement policy.
m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
  TFE_Context* context = tensorflow::InputTFE_Context(ctx);
  return TFE_ContextGetDevicePlacementPolicy(context);
});
// Overrides the device placement policy for the calling thread.
m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
      [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
        TFE_Context* context = tensorflow::InputTFE_Context(ctx);
        TFE_ContextSetThreadLocalDevicePlacementPolicy(context, policy);
      });
// Installs a serialized ServerDef (`proto`) on the context; raises the
// registered Python exception on failure.
m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
                                    py::bytes proto) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  tensorflow::Safe_TF_BufferPtr buf =
      tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
  TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
                          buf.get()->data, buf.get()->length, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Updates the context's ServerDef from a serialized proto. The C call runs
// with the GIL released; raises the registered Python exception on failure.
m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
                                       py::bytes proto) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  tensorflow::Safe_TF_BufferPtr buf =
      tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
  // Unwrap the context while the GIL is still held: InputTFE_Context reads
  // the Python object `ctx`, which must not happen inside the GIL-released
  // region below.
  TFE_Context* context = tensorflow::InputTFE_Context(ctx);
  Py_BEGIN_ALLOW_THREADS;
  TFE_ContextUpdateServerDef(context, keep_alive_secs, buf.get()->data,
                             buf.get()->length, status.get());
  Py_END_ALLOW_THREADS;
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Returns whether `worker_name` is alive; raises the registered Python
// exception if the check itself fails.
m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
                                      worker_name, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  return output;
});
// Waits for the context's pending async work and raises any executor error.
m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  // Unwrap the context while the GIL is still held: InputTFE_Context reads
  // the Python object `ctx`.
  TFE_Context* context = tensorflow::InputTFE_Context(ctx);
  // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
  Py_BEGIN_ALLOW_THREADS;
  TFE_ContextAsyncWait(context, status.get());
  Py_END_ALLOW_THREADS;
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Waits for the context's pending async work, discarding any executor error.
m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  // Unwrap the context while the GIL is still held: InputTFE_Context reads
  // the Python object `ctx`.
  TFE_Context* context = tensorflow::InputTFE_Context(ctx);
  // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
  Py_BEGIN_ALLOW_THREADS;
  TFE_ContextAsyncWait(context, status.get());
  Py_END_ALLOW_THREADS;
  // NOTE: different from TFE_ContextSyncExecutors that raises potential
  // errors, deliberately ignore executor statuses in cleanup.
});
// Coordination-service key/value bindings. Each blocking C call runs with the
// GIL released. The context is unwrapped beforehand because InputTFE_Context
// reads the Python object `ctx` and therefore needs the GIL.
m.def(
    "TFE_InsertConfigKeyValue",
    [](py::handle& ctx, const char* config_key, const char* config_value) {
      tensorflow::Safe_TF_StatusPtr status =
          tensorflow::make_safe(TF_NewStatus());
      TFE_Context* context = tensorflow::InputTFE_Context(ctx);
      Py_BEGIN_ALLOW_THREADS;
      TFE_InsertConfigKeyValue(context, config_key, config_value,
                               status.get());
      Py_END_ALLOW_THREADS;
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    },
    py::return_value_policy::reference);
// Reads `config_key` into the caller-supplied `config_value`, waiting up to
// `timeout_in_ms`.
m.def(
    "TFE_GetConfigKeyValue",
    [](py::handle& ctx, const char* config_key, int64_t timeout_in_ms,
       TF_Buffer& config_value) {
      tensorflow::Safe_TF_StatusPtr status =
          tensorflow::make_safe(TF_NewStatus());
      TFE_Context* context = tensorflow::InputTFE_Context(ctx);
      Py_BEGIN_ALLOW_THREADS;
      TFE_GetConfigKeyValue(context, config_key, timeout_in_ms, &config_value,
                            status.get());
      Py_END_ALLOW_THREADS;
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    },
    py::return_value_policy::reference);
m.def(
    "TFE_DeleteConfigKeyValue",
    [](py::handle& ctx, const char* config_key) {
      tensorflow::Safe_TF_StatusPtr status =
          tensorflow::make_safe(TF_NewStatus());
      TFE_Context* context = tensorflow::InputTFE_Context(ctx);
      Py_BEGIN_ALLOW_THREADS;
      TFE_DeleteConfigKeyValue(context, config_key, status.get());
      Py_END_ALLOW_THREADS;
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    },
    py::return_value_policy::reference);
// Reports an error (code + message) to the cluster; raises the registered
// Python exception if reporting itself fails.
m.def(
    "TFE_ReportErrorToCluster",
    [](py::handle& ctx, int error_code, const char* error_message) {
      tensorflow::Safe_TF_StatusPtr status =
          tensorflow::make_safe(TF_NewStatus());
      TFE_ReportErrorToCluster(tensorflow::InputTFE_Context(ctx), error_code,
                               error_message, status.get());
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    },
    py::return_value_policy::reference);
// Context-wide boolean setters. Each passes a fresh TF_Status to the C API
// and raises the registered Python exception if the call reports an error,
// consistent with the other status-taking bindings in this module.
m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
                                    status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Must call the *Log* setter: calling SetSoftDevicePlacement here would
// toggle soft placement instead of device-placement logging.
m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
                                   status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ContextSetRunEagerOpAsFunction", [](py::handle& ctx, bool enable) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextSetRunEagerOpAsFunction(tensorflow::InputTFE_Context(ctx),
                                     enable, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ContextSetJitCompileRewrite", [](py::handle& ctx, bool enable) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextSetJitCompileRewrite(tensorflow::InputTFE_Context(ctx), enable,
                                  status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Queries the state of every task of every requested job. `job_names[i]` is
// paired with `task_nums[i]`, the number of tasks in that job (task ids
// 0 .. task_nums[i]-1). Returns a Python list with one entry per task: None
// for a task whose state is TF_OK, otherwise an instance of the registered
// Python exception class for its error code (args: None, None, message,
// payloads). Raises if the sizes mismatch or the query itself fails.
m.def("TFE_GetTaskStates", [](py::handle& ctx,
                              const std::vector<std::string>& job_names,
                              const std::vector<int>& task_nums) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  if (job_names.size() != task_nums.size()) {
    status->status = tensorflow::errors::InvalidArgument(
        "The size of job names is not equal to the size of task nums.");
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  }
  std::vector<tensorflow::CoordinatedTask> coordinated_tasks;
  for (size_t i = 0; i < job_names.size(); ++i) {
    // Signed loop index: comparing a size_t against a negative count would
    // convert the count to a huge unsigned bound. A non-positive count simply
    // contributes no tasks.
    for (int j = 0; j < task_nums[i]; ++j) {
      auto& coordinated_task = coordinated_tasks.emplace_back();
      coordinated_task.set_job_name(job_names[i]);
      coordinated_task.set_task_id(j);
    }
  }
  const size_t task_len = coordinated_tasks.size();
  auto state = std::make_unique<TF_Status[]>(task_len);
  // Value-initialize so the fields not assigned below are null rather than
  // indeterminate. `length` carries the number of tasks, not a byte count.
  TF_Buffer tasks{};
  tasks.data = coordinated_tasks.data();
  tasks.length = task_len;
  TFE_GetTaskStates(tensorflow::InputTFE_Context(ctx), tasks, state.get(),
                    status.get());
  // Surface a failure of the query itself before consuming per-task states.
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  py::list output(task_len);
  for (size_t i = 0; i < task_len; ++i) {
    auto code = TF_GetCode(&state[i]);
    if (code == TF_Code::TF_OK) {
      output[i] = py::none();
      continue;
    }
    py::dict payloads;
    for (const auto& payload :
         tensorflow::errors::GetPayloads(state[i].status)) {
      payloads[payload.first.c_str()] = payload.second;
    }
    // NOTE(review): PyExceptionRegistry::Lookup hands out the registry's own
    // stored pointer (a borrowed reference, per py_exception_registry.cc), so
    // borrow it; stealing would decref the registered class on every call.
    auto exception_class = py::reinterpret_borrow<py::object>(
        tensorflow::PyExceptionRegistry::Lookup(code));
    if (!exception_class) {
      status->status = tensorflow::errors::Internal(absl::StrCat(
          "Fail to find the corresponding exception class for ", code));
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    }
    output[i] = exception_class(py::none(), py::none(),
                                TF_Message(&state[i]), payloads);
  }
  return tensorflow::PyoOrThrow(output.release().ptr());
});
// Blocks at `barrier_id` for up to `timeout_in_ms`; raises the registered
// Python exception on failure.
m.def("TFE_WaitAtBarrier",
      [](py::handle& ctx, const char* barrier_id, int64_t timeout_in_ms) {
        auto status = tensorflow::make_safe(TF_NewStatus());
        TFE_Context* context = tensorflow::InputTFE_Context(ctx);
        TFE_WaitAtBarrier(context, barrier_id, timeout_in_ms, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      });
// TFE_Executor logic
// The new executor is returned with the `reference` policy, so Python does
// not take ownership; TFE_DeleteExecutor is exposed separately.
m.def(
    "TFE_NewExecutor",
    [](const bool is_async, const bool enable_streaming_enqueue,
       const int in_flight_nodes_limit) {
      return TFE_NewExecutor(is_async, enable_streaming_enqueue,
                             in_flight_nodes_limit);
    },
    py::return_value_policy::reference);
m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
// Waits for every pending node on `exc`; raises on executor error.
m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
  auto status = tensorflow::make_safe(TF_NewStatus());
  TFE_Executor* executor = &exc;
  // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
  Py_BEGIN_ALLOW_THREADS;
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  Py_END_ALLOW_THREADS;
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
// Installs `exc` as the calling thread's executor for the context.
m.def("TFE_ContextSetExecutorForThread",
      [](py::handle& ctx, TFE_Executor& exc) {
        TFE_Context* context = tensorflow::InputTFE_Context(ctx);
        TFE_ContextSetExecutorForThread(context, &exc);
      });
// Returns the calling thread's executor (not owned by Python).
m.def(
    "TFE_ContextGetExecutorForThread",
    [](py::handle& ctx) {
      TFE_Context* context = tensorflow::InputTFE_Context(ctx);
      return TFE_ContextGetExecutorForThread(context);
    },
    py::return_value_policy::reference);
// Returns the TF_AttrType (as a Python int) of attribute `attr_name` on the
// op or function `op_or_function_name`. When the C API reports a list-valued
// attribute, the type is wrapped in a one-element Python list. Raises the
// registered Python exception on failure.
m.def("TFE_OpNameGetAttrType",
      [](py::handle& ctx, const char* op_or_function_name,
         const char* attr_name) -> py::object {
        // Out-parameter written by the C API; set to 1 for list attributes.
        unsigned char is_list = 0;
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
                                            op_or_function_name, attr_name,
                                            &is_list, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        // PyoOrThrow takes ownership, so nothing leaks if a later step
        // throws.
        py::object attr_type = tensorflow::PyoOrThrow(PyLong_FromLong(output));
        if (is_list == 1) {
          py::list wrapped;
          wrapped.append(attr_type);
          return py::object(std::move(wrapped));
        }
        return attr_type;
      });
// Eager tensor / execution bindings. Functions returning a new PyObject* are
// wrapped with PyoOrThrow, which turns a null result or pending Python error
// into a C++ exception.
m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
  PyObject* result = TFE_Py_InitEagerTensor(o.ptr());
  return tensorflow::PyoOrThrow(result);
});
m.def("TFE_Py_PackEagerTensors",
      [](const py::handle& context, const py::handle& handles) {
        return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
      });
m.def("TFE_Py_SetEagerTensorProfiler", [](const py::handle& o) {
  PyObject* result = TFE_Py_SetEagerTensorProfiler(o.ptr());
  return tensorflow::PyoOrThrow(result);
});
m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
  PyObject* result = TFE_Py_RegisterJVPFunction(o.ptr());
  return tensorflow::PyoOrThrow(result);
});
m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
  PyObject* result = TFE_Py_RegisterGradientFunction(o.ptr());
  return tensorflow::PyoOrThrow(result);
});
// Non-cancelable execution: same wrapper as below with no
// CancellationManager.
m.def("TFE_Py_Execute",
      [](const py::handle& context, const char* device_name,
         const char* op_name, const py::handle& inputs,
         const py::handle& attrs, const py::handle& num_outputs) {
        tensorflow::CancellationManager* no_cancellation = nullptr;
        return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
            context, device_name, op_name, inputs, attrs.ptr(),
            no_cancellation, num_outputs);
      });
m.def(
    "TFE_Py_ExecuteCancelable",
    [](const py::handle& context, const char* device_name,
       const char* op_name, const py::handle& inputs, const py::handle& attrs,
       tensorflow::CancellationManager& cancellation_manager,
       const py::handle& num_outputs) {
      tensorflow::CancellationManager* manager = &cancellation_manager;
      return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
          context, device_name, op_name, inputs, attrs.ptr(), manager,
          num_outputs);
    });
m.def("TFE_Py_FastPathExecute", [](const py::args args) {
  // TFE_Py_FastPathExecute requires error checking prior to returning.
  PyObject* result = TFE_Py_FastPathExecute_C(args.ptr());
  return tensorflow::PyoOrThrow(result);
});
m.def("TFE_Py_RecordGradient",
      [](const py::handle& op_name, const py::handle& inputs,
         const py::handle& attrs, const py::handle& results,
         const py::handle& forward_pass_name_scope) {
        PyObject* recorded = TFE_Py_RecordGradient(
            op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
            forward_pass_name_scope.ptr());
        return tensorflow::PyoOrThrow(recorded);
      });
m.def("TFE_Py_UID", [] {
  PyObject* uid = TFE_Py_UID();
  return tensorflow::PyoOrThrow(uid);
});
// TFE_Py_Tape Logic
// Thin wrappers over the gradient-tape C functions. Results that are new
// PyObject*s go through PyoOrThrow; void functions are forwarded directly.
m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
                              const py::handle& watch_accessed_variables) {
  PyObject* tape =
      TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr());
  return tensorflow::PyoOrThrow(tape);
});
m.def("TFE_Py_TapeSetAdd", [](const py::handle& tape) {
  TFE_Py_TapeSetAdd(tape.ptr());
});
m.def("TFE_Py_TapeSetRemove", [](const py::handle& tape) {
  TFE_Py_TapeSetRemove(tape.ptr());
});
m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
m.def("TFE_Py_TapeSetIsStopped", [] {
  PyObject* stopped = TFE_Py_TapeSetIsStopped();
  return tensorflow::PyoOrThrow(stopped);
});
m.def("TFE_Py_TapeSetIsEmpty", [] {
  PyObject* empty = TFE_Py_TapeSetIsEmpty();
  return tensorflow::PyoOrThrow(empty);
});
m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
  PyObject* should_record = TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr());
  return tensorflow::PyoOrThrow(should_record);
});
m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
  PyObject* gradient_types =
      TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr());
  return tensorflow::PyoOrThrow(gradient_types);
});
m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
m.def("TFE_Py_TapeSetRecordOperation",
      [](const py::handle& op_type, const py::handle& output_tensors,
         const py::handle& input_tensors, const py::handle& backward_function,
         const py::handle& forward_function) {
        PyObject* recorded = TFE_Py_TapeSetRecordOperation(
            op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
            backward_function.ptr(), forward_function.ptr());
        return tensorflow::PyoOrThrow(recorded);
      });
m.def("TFE_Py_TapeSetRecordOperationBackprop",
      [](const py::handle& op_type, const py::handle& output_tensors,
         const py::handle& input_tensors,
         const py::handle& backward_function) {
        PyObject* recorded = TFE_Py_TapeSetRecordOperationBackprop(
            op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
            backward_function.ptr());
        return tensorflow::PyoOrThrow(recorded);
      });
m.def("TFE_Py_TapeSetRecordOperationForwardprop",
      [](const py::handle& op_type, const py::handle& output_tensors,
         const py::handle& input_tensors, const py::handle& backward_function,
         const py::handle& forwardprop_output_indices) {
        PyObject* recorded = TFE_Py_TapeSetRecordOperationForwardprop(
            op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
            backward_function.ptr(), forwardprop_output_indices.ptr());
        return tensorflow::PyoOrThrow(recorded);
      });
// Computes gradients from `tape`; a status error is raised before the result
// is wrapped.
m.def("TFE_Py_TapeGradient",
      [](const py::handle& tape, const py::handle& target,
         const py::handle& sources, const py::handle& output_gradients,
         const py::handle& sources_raw,
         const py::handle& unconnected_gradients) {
        auto status = tensorflow::make_safe(TF_NewStatus());
        PyObject* gradients = TFE_Py_TapeGradient(
            tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
            sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return tensorflow::PyoOrThrow(gradients);
      });
m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
  TFE_Py_TapeVariableAccessed(variable.ptr());
});
m.def("TFE_Py_TapeWatch", [](const py::handle& tape, const py::handle& tensor) {
  TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
});
m.def("TFE_Py_TapeWatchVariable",
      [](const py::handle& tape, const py::handle& variable) {
        TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
      });
m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
  PyObject* watched = TFE_Py_TapeWatchedVariables(tape.ptr());
  return tensorflow::PyoOrThrow(watched);
});
// TFE_Py_VariableWatcher logic.
m.def("TFE_Py_VariableWatcherNew", [] {
  PyObject* watcher = TFE_Py_VariableWatcherNew();
  return tensorflow::PyoOrThrow(watcher);
});
m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
  TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
});
m.def("TFE_Py_VariableWatcherVariableAccessed",
      [](const py::handle& variable) {
        TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
      });
m.def("TFE_Py_VariableWatcherWatchedVariables",
      [](const py::handle& variable_watcher) {
        PyObject* watched =
            TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr());
        return tensorflow::PyoOrThrow(watched);
      });
// TFE_Py_ForwardAccumulator logic.
m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
  PyObject* accumulator = TFE_Py_ForwardAccumulatorNew(use_batch);
  return tensorflow::PyoOrThrow(accumulator);
});
m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
  PyObject* result = TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr());
  return tensorflow::PyoOrThrow(result);
});
m.def("TFE_Py_ForwardAccumulatorSetRemove",
      [](const py::handle& accumulator) {
        TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
      });
m.def("TFE_Py_ForwardAccumulatorWatch",
      [](const py::handle& accumulator, const py::handle& tensor,
         const py::handle& tangent) {
        TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
                                       tangent.ptr());
      });
m.def("TFE_Py_ForwardAccumulatorJVP",
      [](const py::handle& accumulator, const py::handle& tensor) {
        PyObject* jvp =
            TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr());
        return tensorflow::PyoOrThrow(jvp);
      });
m.def("TFE_Py_ForwardAccumulatorPushState", [] {
  PyObject* result = TFE_Py_ForwardAccumulatorPushState();
  return tensorflow::PyoOrThrow(result);
});
m.def("TFE_Py_ForwardAccumulatorPopState", [] {
  PyObject* result = TFE_Py_ForwardAccumulatorPopState();
  return tensorflow::PyoOrThrow(result);
});
m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
  PyObject* packed = TFE_Py_PackJVPs(tensors.ptr());
  return tensorflow::PyoOrThrow(packed);
});
// TFE_ContextOptions Logic
m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
      py::return_value_policy::reference);
// Applies a serialized config proto to `options`; raises the registered
// Python exception on failure.
m.def("TFE_ContextOptionsSetConfig",
      [](TFE_ContextOptions* options, py::bytes proto) {
        auto status = tensorflow::make_safe(TF_NewStatus());
        auto buf = tensorflow::make_safe(
            tensorflow::ProtoStringToTFBuffer(proto.ptr()));
        TF_Buffer* config = buf.get();
        TFE_ContextOptionsSetConfig(options, config->data, config->length,
                                    status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      });
m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
      &TFE_ContextOptionsSetDevicePlacementPolicy);
m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
// Experimental feature, intentionally not exposed as a C API yet; the field is
// written directly on the options struct.
m.def("TFE_ContextOptionsSetRunEagerOpAsFunction",
      [](TFE_ContextOptions* options, bool run_eager_op_as_function) {
        options->run_eager_op_as_function = run_eager_op_as_function;
      });
m.def("TFE_ContextOptionsSetJitCompileRewrite",
      [](TFE_ContextOptions* options, bool jit_compile_rewrite) {
        options->jit_compile_rewrite = jit_compile_rewrite;
      });
m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
      py::return_value_policy::reference);
// TFE_Py_TensorShape Logic
m.def("TFE_Py_TensorShapeSlice", [](const py::handle& tensors, int slice_dim) {
  PyObject* sliced = TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim);
  return tensorflow::PyoOrThrow(sliced);
});
// NOTE(review): the second argument is accepted for call compatibility but is
// not forwarded to TFE_Py_TensorShapeOnDevice.
m.def("TFE_Py_TensorShapeOnDevice",
      [](const py::handle& tensors, int /*slice_dim (unused)*/) {
        PyObject* shape = TFE_Py_TensorShapeOnDevice(tensors.ptr());
        return tensorflow::PyoOrThrow(shape);
      });
m.def("TFE_Py_EnableInteractivePythonLogging",
      &TFE_Py_EnableInteractivePythonLogging);
// Additional Context Logic
m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
  PyObject* result = TFE_Py_SetEagerContext(o.ptr());
  return tensorflow::PyoOrThrow(result);
});
// Sets (or, when passed None, clears) the C-level eager context.
m.def("TFE_Py_SetCEagerContext", [](const py::handle& ctx) {
  // TODO(mdan): This cast might need rewriting to ImmediateExecutionContext.
  tensorflow::SetCEagerContext(
      ctx.is_none() ? nullptr
                    : reinterpret_cast<tensorflow::EagerContext*>(
                          tensorflow::InputTFE_Context(ctx)));
});
m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
  PyObject* result = TFE_Py_RegisterVSpace(o.ptr());
  return tensorflow::PyoOrThrow(result);
});
// Enables collective ops from a serialized proto; raises on failure.
m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
  auto status = tensorflow::make_safe(TF_NewStatus());
  auto buf =
      tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
  TF_Buffer* config = buf.get();
  TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), config->data,
                          config->length, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// Aborts collective ops. Here the status is an *input* carrying the abort
// reason (code + message), so it is never raised.
m.def("TFE_AbortCollectiveOps",
      [](const py::handle& ctx, int code, const char* message) {
        auto reason = tensorflow::make_safe(TF_NewStatus());
        TF_SetStatus(reason.get(), static_cast<TF_Code>(code), message);
        TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx),
                               reason.get());
      });
m.def("TFE_CollectiveOpsCheckPeerHealth",
      [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
        auto status = tensorflow::make_safe(TF_NewStatus());
        TFE_Context* context = tensorflow::InputTFE_Context(ctx);
        TFE_CollectiveOpsCheckPeerHealth(context, task, timeout_in_ms,
                                         status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      });
// Physical device enumeration and TF_DeviceList accessors.
m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
m.def("TF_ListPluggablePhysicalDevices",
      &tensorflow::TF_ListPluggablePhysicalDevices);
m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
      py::return_value_policy::reference);
m.def("TF_DeviceListCount", &TF_DeviceListCount);
// Name of the device at `index`; raises on an invalid index or other error.
m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
  auto status = tensorflow::make_safe(TF_NewStatus());
  auto device_name = TF_DeviceListName(list, index, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  return device_name;
});
// Type of the device at `index`; raises on an invalid index or other error.
m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
  auto status = tensorflow::make_safe(TF_NewStatus());
  auto device_type = TF_DeviceListType(list, index, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  return device_type;
});
m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
// TFE_MonitoringCounter Logic
m.def("TFE_MonitoringCounterCellIncrementBy",
      &TFE_MonitoringCounterCellIncrementBy);
m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
// Counter constructors (0, 1 or 2 labels). Each raises the registered Python
// exception on failure and returns the raw pointer with the `reference`
// policy, so Python does not take ownership (presumably released through the
// matching TFE_MonitoringDeleteCounterN binding).
m.def(
    "TFE_MonitoringNewCounter0",
    [](const char* name, const char* description) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* counter =
          TFE_MonitoringNewCounter0(name, status.get(), description);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return counter;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewCounter1",
    [](const char* name, const char* description, const char* label1) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* counter =
          TFE_MonitoringNewCounter1(name, status.get(), description, label1);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return counter;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewCounter2",
    [](const char* name, const char* description, const char* label1,
       const char* label2) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* counter = TFE_MonitoringNewCounter2(name, status.get(),
                                                description, label1, label2);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return counter;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
      py::return_value_policy::reference);
// TFE_MonitoringIntGauge Logic
m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
// Int gauge constructors (0, 1 or 2 labels). Each raises the registered
// Python exception on failure and returns the raw pointer with the
// `reference` policy (Python does not take ownership).
m.def(
    "TFE_MonitoringNewIntGauge0",
    [](const char* name, const char* description) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge = TFE_MonitoringNewIntGauge0(name, status.get(), description);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewIntGauge1",
    [](const char* name, const char* description, const char* label1) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge =
          TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewIntGauge2",
    [](const char* name, const char* description, const char* label1,
       const char* label2) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge = TFE_MonitoringNewIntGauge2(name, status.get(), description,
                                               label1, label2);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
      py::return_value_policy::reference);
// TFE_MonitoringStringGauge Logic
m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
m.def("TFE_MonitoringStringGaugeCellValue",
      &TFE_MonitoringStringGaugeCellValue);
// String gauge constructors (0-4 labels). Each raises the registered Python
// exception on failure and returns the raw pointer with the `reference`
// policy (Python does not take ownership).
m.def(
    "TFE_MonitoringNewStringGauge0",
    [](const char* name, const char* description) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge =
          TFE_MonitoringNewStringGauge0(name, status.get(), description);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewStringGauge1",
    [](const char* name, const char* description, const char* label1) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge = TFE_MonitoringNewStringGauge1(name, status.get(),
                                                  description, label1);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewStringGauge2",
    [](const char* name, const char* description, const char* label1,
       const char* label2) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge = TFE_MonitoringNewStringGauge2(name, status.get(),
                                                  description, label1, label2);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewStringGauge3",
    [](const char* name, const char* description, const char* label1,
       const char* label2, const char* label3) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge = TFE_MonitoringNewStringGauge3(
          name, status.get(), description, label1, label2, label3);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteStringGauge3", &TFE_MonitoringDeleteStringGauge3);
m.def("TFE_MonitoringGetCellStringGauge3", &TFE_MonitoringGetCellStringGauge3,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewStringGauge4",
    [](const char* name, const char* description, const char* label1,
       const char* label2, const char* label3, const char* label4) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge = TFE_MonitoringNewStringGauge4(
          name, status.get(), description, label1, label2, label3, label4);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteStringGauge4", &TFE_MonitoringDeleteStringGauge4);
m.def("TFE_MonitoringGetCellStringGauge4", &TFE_MonitoringGetCellStringGauge4,
      py::return_value_policy::reference);
// TFE_MonitoringBoolGauge Logic
m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
// Bool gauge constructors (0, 1 or 2 labels). Each raises the registered
// Python exception on failure and returns the raw pointer with the
// `reference` policy (Python does not take ownership).
m.def(
    "TFE_MonitoringNewBoolGauge0",
    [](const char* name, const char* description) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge =
          TFE_MonitoringNewBoolGauge0(name, status.get(), description);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewBoolGauge1",
    [](const char* name, const char* description, const char* label1) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge =
          TFE_MonitoringNewBoolGauge1(name, status.get(), description, label1);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewBoolGauge2",
    [](const char* name, const char* description, const char* label1,
       const char* label2) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* gauge = TFE_MonitoringNewBoolGauge2(name, status.get(),
                                                description, label1, label2);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return gauge;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
      py::return_value_policy::reference);
// TFE_MonitoringSampler Logic
m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
m.def("TFE_MonitoringNewExponentialBuckets",
      &TFE_MonitoringNewExponentialBuckets,
      py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
      py::return_value_policy::reference);
// Sampler constructors (0, 1 or 2 labels) over caller-supplied `buckets`.
// Each raises the registered Python exception on failure and returns the raw
// pointer with the `reference` policy (Python does not take ownership).
m.def(
    "TFE_MonitoringNewSampler0",
    [](const char* name, TFE_MonitoringBuckets* buckets,
       const char* description) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* sampler =
          TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return sampler;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewSampler1",
    [](const char* name, TFE_MonitoringBuckets* buckets,
       const char* description, const char* label1) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* sampler = TFE_MonitoringNewSampler1(name, buckets, status.get(),
                                                description, label1);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return sampler;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
      py::return_value_policy::reference);
m.def(
    "TFE_MonitoringNewSampler2",
    [](const char* name, TFE_MonitoringBuckets* buckets,
       const char* description, const char* label1, const char* label2) {
      auto status = tensorflow::make_safe(TF_NewStatus());
      auto* sampler = TFE_MonitoringNewSampler2(name, buckets, status.get(),
                                                description, label1, label2);
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      return sampler;
    },
    py::return_value_policy::reference);
m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
      py::return_value_policy::reference);
m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
      py::return_value_policy::reference);
// TFE_CancellationManager Logic
m.def("TFE_NewCancellationManager",
[]() { return new tensorflow::CancellationManager(); });
m.def("TFE_CancellationManagerIsCancelled",
&tensorflow::CancellationManager::IsCancelled);
m.def("TFE_CancellationManagerStartCancel",
&tensorflow::CancellationManager::StartCancel);
m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);
// TF_Buffer helpers; the returned buffer is not owned by pybind11.
m.def("TF_NewBufferFromString",
      &TF_NewBufferFromString,
      py::return_value_policy::reference);
// DLPack functions
//
// TFE_ToDlpackCapsule: exports an EagerTensor as a DLPack capsule named
// tensorflow::kDlTensorCapsuleName. The capsule's destructor frees the
// DLManagedTensor only while the capsule still carries that name. A consumer
// that takes the tensor renames the capsule (TFE_FromDlpackCapsule renames it
// to "used_dltensor"), so a consumed capsule is never freed twice.
// Raises InvalidArgument if `o` is not exactly an EagerTensor.
m.def("TFE_ToDlpackCapsule", [](const py::handle& o) {
  PyObject* eager_tensor_pyobject_ptr = o.ptr();
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
    status->status = tensorflow::errors::InvalidArgument(
        "The argument to `to_dlpack` must be a TF tensor, not Python object");
    // Raises on the non-OK status, so EagerTensor_Handle below is only
    // reached for real EagerTensors.
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  }
  TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
  void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  py::capsule capsule(
      dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
        // A renamed (consumed) capsule fails PyCapsule_IsValid, so only an
        // unconsumed capsule frees its DLManagedTensor here.
        if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
          void* dlm_rptr =
              PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
          if (dlm_rptr) {
            tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
            PyCapsule_SetDestructor(capsule, nullptr);
          }
        }
      });
  return capsule;
});
// TFE_FromDlpackCapsule: imports a DLPack capsule (name must equal
// tensorflow::kDlTensorCapsuleName) as an EagerTensor in `context`.
// On success the capsule is renamed to "used_dltensor" and its destructor is
// cleared. A second import of the same capsule then fails the name check,
// which is what enforces "consumed at most once".
// Raises InvalidArgument on a wrong or missing capsule name, and propagates
// any error reported by TFE_HandleFromDLPack.
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
                                  const py::handle& context) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  // PyCapsule_GetName returns nullptr for an unnamed capsule. Building a
  // string_view from nullptr is undefined when absl::string_view aliases
  // std::string_view, so map it to an empty view instead.
  const char* raw_capsule_name = pycapsule.name();
  const absl::string_view capsule_name =
      raw_capsule_name == nullptr ? absl::string_view()
                                  : absl::string_view(raw_capsule_name);
  if (capsule_name != tensorflow::kDlTensorCapsuleName) {
    // errors::InvalidArgument concatenates its arguments (StrCat-style) and
    // does not interpret printf directives, so splice the name in directly.
    status->status = tensorflow::errors::InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"",
        capsule_name,
        "\". Note that a DLPack tensor may be consumed at most once.");
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  }
  TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
      pycapsule, status.get(), tensorflow::InputTFE_Context(context));
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  // The tensor is now referenced by `thandle` (presumably taking ownership).
  // Mark the capsule consumed and drop its destructor so the data is not
  // freed when the capsule is garbage-collected.
  PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
  PyObject* pyhandle = EagerTensorFromHandle(thandle);
  return tensorflow::PyoOrThrow(pyhandle);
});
m.def("TFE_Py_IsCustomDevice",
      [](const py::handle& py_context, const char* device_name) {
        // Unwrap the Python eager context, then let the C API decide whether
        // `device_name` names a custom device in that context.
        auto eager_ctx = tensorflow::InputTFE_Context(py_context);
        return TFE_IsCustomDevice(eager_ctx, device_name);
      });
// TFE_Py_RegisterCustomDevice: registers a custom device under `device_name`
// in `context`.
//   device      - capsule named "TFE_CustomDevice" wrapping a TFE_CustomDevice.
//   device_info - capsule named "TFE_CustomDevice_DeviceInfo"; its pointer is
//                 handed to TFE_RegisterCustomDevice.
// Raises InvalidArgument if either capsule has the wrong (or no) name, and
// propagates any error reported by TFE_RegisterCustomDevice.
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
                                        const py::capsule& device,
                                        const char* device_name,
                                        const py::capsule& device_info) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  // PyCapsule_GetName returns nullptr for an unnamed capsule. Map that to an
  // empty view rather than constructing a string_view from nullptr, which is
  // undefined when absl::string_view aliases std::string_view.
  auto capsule_name = [](const py::capsule& capsule) {
    const char* raw_name = capsule.name();
    return raw_name == nullptr ? absl::string_view()
                               : absl::string_view(raw_name);
  };
  if (capsule_name(device) != "TFE_CustomDevice") {
    status->status = tensorflow::errors::InvalidArgument(
        "Expected a capsule named 'TFE_CustomDevice' for the `device` "
        "argument, got ",
        capsule_name(device));
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  }
  if (capsule_name(device_info) != "TFE_CustomDevice_DeviceInfo") {
    status->status = tensorflow::errors::InvalidArgument(
        "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
        "the `device_info` argument, got ",
        capsule_name(device_info));
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  }
  // TFE_RegisterCustomDevice takes ownership
  PyCapsule_SetDestructor(device_info.ptr(), nullptr);
  TFE_RegisterCustomDevice(
      tensorflow::InputTFE_Context(context),
      *reinterpret_cast<TFE_CustomDevice*>(
          PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
      device_name,
      PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
      status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
py::class_<EagerContextThreadLocalDataWrapper>(m,
"EagerContextThreadLocalData")
.def(py::init<py::handle, py::handle, py::handle>(),
py::arg("py_eager_context"), py::arg("is_eager"),
py::arg("device_spec"))
.def_property("is_eager",
&EagerContextThreadLocalDataWrapper::get_is_eager,
&EagerContextThreadLocalDataWrapper::set_is_eager)
.def_property(
"invoking_op_callbacks",
&EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
&EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
.def_property("device_name",
&EagerContextThreadLocalDataWrapper::get_device_name,
&EagerContextThreadLocalDataWrapper::set_device_name)
.def_property("scope_name",
&EagerContextThreadLocalDataWrapper::get_scope_name,
&EagerContextThreadLocalDataWrapper::set_scope_name)
.def_property("device_spec",
&EagerContextThreadLocalDataWrapper::get_device_spec,
&EagerContextThreadLocalDataWrapper::set_device_spec)
.def_property(
"function_call_options",
&EagerContextThreadLocalDataWrapper::get_function_call_options,
&EagerContextThreadLocalDataWrapper::set_function_call_options)
.def_property("executor",
&EagerContextThreadLocalDataWrapper::get_executor,
&EagerContextThreadLocalDataWrapper::set_executor)
.def_property("op_callbacks",
&EagerContextThreadLocalDataWrapper::get_op_callbacks,
&EagerContextThreadLocalDataWrapper::set_op_callbacks);
// C API Enum
py::enum_<TFE_ContextDevicePlacementPolicy>(
m, "TFE_ContextDevicePlacementPolicy")
.value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
.value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
.value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
.value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
.export_values();
py::enum_<TF_AttrType>(m, "TF_AttrType")
.value("TF_ATTR_STRING", TF_ATTR_STRING)
.value("TF_ATTR_INT", TF_ATTR_INT)
.value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
.value("TF_ATTR_BOOL", TF_ATTR_BOOL)
.value("TF_ATTR_TYPE", TF_ATTR_TYPE)
.value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
.value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
.value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
.value("TF_ATTR_FUNC", TF_ATTR_FUNC)
.export_values();
};