blob: 01b7f65af9b591e7f8117a73c55ca422da117978 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <vector>
#include "pybind11/pybind11.h" // from @pybind11
#include "pybind11/stl.h" // from @pybind11
#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/dtensor/cc/dtensor_device.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
namespace py = ::pybind11;
using tensorflow::dtensor::AddMesh;
using tensorflow::dtensor::AllocateDTensorDevice;
using tensorflow::dtensor::ClearTPUCoreIDs;
using tensorflow::dtensor::ExperimentalClearDefaultLayout;
using tensorflow::dtensor::ExperimentalClearDefaultMesh;
using tensorflow::dtensor::ExperimentalSetDefaultLayout;
using tensorflow::dtensor::ExperimentalSetDefaultMesh;
using tensorflow::dtensor::FetchLayout;
using tensorflow::dtensor::GetFunctionCacheStats;
using tensorflow::dtensor::IsDTensor;
using tensorflow::dtensor::IsSparseDTensor;
using tensorflow::dtensor::Layout;
using tensorflow::dtensor::Mesh;
using tensorflow::dtensor::Pack;
using tensorflow::dtensor::SetIteratorElementLayouts;
using tensorflow::dtensor::SetTPUCoreIDs;
using tensorflow::dtensor::SparsePack;
using tensorflow::dtensor::TPUCoreIDsToLocations;
using tensorflow::dtensor::TPUCoreLocationsToIDs;
using tensorflow::dtensor::Unpack;
void PyXDecref(PyObject* obj) { Py_XDECREF(obj); }
void CallDelete_Device(PyObject* capsule) {
delete reinterpret_cast<TFE_CustomDevice*>(
PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
}
void CallDelete_DeviceInfo(PyObject* capsule) {
void (*destructor)(void*) =
reinterpret_cast<void (*)(void*)>(PyCapsule_GetContext(capsule));
destructor(PyCapsule_GetPointer(capsule, "TFE_CustomDevice_DeviceInfo"));
}
bool CheckResourceVariable(PyObject* item) {
if (tensorflow::swig::IsResourceVariable(item)) {
tensorflow::Safe_PyObjectPtr handle(
PyObject_GetAttrString(item, "_handle"));
return EagerTensor_CheckExact(handle.get());
}
return false;
}
// Supports 2 cases:
// i) input is an EagerTensor.
// ii) input is an arbitrary python list/tuple.
void ConvertToTensor(TFE_Context* ctx, PyObject* input,
tensorflow::Safe_PyObjectPtr* output_handle,
TF_Status* status) {
if (EagerTensor_CheckExact(input)) {
// Input is already a EagerTensor so increment the reference, since the
// caller will use it through output_handle.
Py_INCREF(input);
output_handle->reset(input);
return;
}
if (CheckResourceVariable(input)) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Variable input is not supported.");
return;
}
TFE_TensorHandle* handle =
tensorflow::ConvertToEagerTensor(ctx, input, tensorflow::DT_INVALID);
if (handle == nullptr) {
TF_SetStatus(status, TF_INTERNAL, "Failure converting to eager tensor.");
return;
}
output_handle->reset(EagerTensorFromHandle(handle));
}
PYBIND11_MODULE(_pywrap_dtensor_device, m) {
pybind11_protobuf::ImportNativeProtoCasters();
m.def("Allocate", [](const std::string& name, bool is_async,
int in_flight_nodes_limit) {
TFE_CustomDevice* device = new TFE_CustomDevice;
std::unique_ptr<PyObject, decltype(&PyXDecref)> device_capsule(
PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device),
PyXDecref);
void* device_info = nullptr;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AllocateDTensorDevice(name, device, &device_info, is_async,
in_flight_nodes_limit, status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
std::unique_ptr<PyObject, decltype(&PyXDecref)> device_info_capsule(
PyCapsule_New(device_info, "TFE_CustomDevice_DeviceInfo",
&CallDelete_DeviceInfo),
PyXDecref);
// The PyCapsule destructor needs a pointer to the destructor for
// DeviceInfo.
PyCapsule_SetContext(device_info_capsule.get(),
reinterpret_cast<void*>(device->delete_device));
if (PyErr_Occurred()) throw py::error_already_set();
return pybind11::reinterpret_steal<pybind11::object>(
PyTuple_Pack(2, device_capsule.get(), device_info_capsule.get()));
});
m.def("AddMesh", [](const py::capsule& device_info,
const std::string& serialized_mesh, bool is_host_mesh) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AddMesh(
serialized_mesh,
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
is_host_mesh, status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
});
m.def(
"ExperimentalSetDefaultLayout",
[](const py::capsule& device_info, const std::string& serialized_layout) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
ExperimentalSetDefaultLayout(
serialized_layout,
PyCapsule_GetPointer(device_info.ptr(),
"TFE_CustomDevice_DeviceInfo"),
status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
});
m.def("ExperimentalClearDefaultLayout", [](const py::capsule& device_info) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
ExperimentalClearDefaultLayout(
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
});
m.def("ExperimentalSetDefaultMesh", [](const py::capsule& device_info,
const std::string& serialized_mesh) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
ExperimentalSetDefaultMesh(
serialized_mesh,
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
});
m.def("ExperimentalClearDefaultMesh", [](const py::capsule& device_info) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
ExperimentalClearDefaultMesh(
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
});
m.def("SetTPUCoreIDs", [](const py::capsule& device_info,
const std::string& mesh_name,
const std::vector<int>& tpu_core_ids) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
SetTPUCoreIDs(
mesh_name, tpu_core_ids,
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
});
m.def("ClearTPUCoreIDs", [](const py::capsule& device_info) {
ClearTPUCoreIDs(
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"));
});
m.def("TPUCoreIDsToLocations", [](const py::handle& context,
const py::capsule& device_info,
const std::vector<int>& tpu_core_ids) {
return TPUCoreIDsToLocations(
static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
tpu_core_ids,
PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"));
});
m.def("TPUCoreLocationsToIDs",
[](const py::handle& context, const py::capsule& device_info,
const std::vector<std::vector<int>>& tpu_core_locations) {
return TPUCoreLocationsToIDs(
static_cast<TFE_Context*>(
PyCapsule_GetPointer(context.ptr(), nullptr)),
tpu_core_locations,
PyCapsule_GetPointer(device_info.ptr(),
"TFE_CustomDevice_DeviceInfo"));
});
m.def("Pack", [](const py::handle& context, const py::handle& input_tensors,
const std::string& string_layout,
const py::capsule& device_info, const bool is_sparse) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* ctx =
static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr));
// Convert each python object to safe py eagertensors.
std::vector<tensorflow::Safe_PyObjectPtr> py_eager_tensor_handles;
Py_ssize_t len = PyList_Size(input_tensors.ptr());
py_eager_tensor_handles.resize(len);
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
ConvertToTensor(ctx, elem, &py_eager_tensor_handles[i], status.get());
if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
return tensorflow::PyoOrThrow(nullptr);
}
std::vector<TFE_TensorHandle*> input_vector;
input_vector.resize(len);
for (int i = 0; i < len; ++i)
input_vector[i] = EagerTensor_Handle(py_eager_tensor_handles[i].get());
TFE_TensorHandle* packed_tensor;
if (is_sparse) {
auto size = input_vector.size() / 3;
packed_tensor = SparsePack(
ctx,
/*num_inputs=*/input_vector.size() / 3,
/*indices=*/
std::vector<TFE_TensorHandle*>(input_vector.begin(),
input_vector.begin() + size)
.data(),
/*values=*/
std::vector<TFE_TensorHandle*>(input_vector.begin() + size,
input_vector.begin() + 2 * size)
.data(),
/*shapes=*/
std::vector<TFE_TensorHandle*>(input_vector.begin() + 2 * size,
input_vector.end())
.data(),
string_layout, device_info, status.get());
} else {
packed_tensor = Pack(ctx, input_vector.size(), input_vector.data(),
string_layout, device_info, status.get());
}
if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
return tensorflow::PyoOrThrow(nullptr);
// Convert c++ packed tensor handle into a python eager tensor object.
tensorflow::Safe_PyObjectPtr flat_result(PyList_New(1));
PyList_SET_ITEM(flat_result.get(), 0, EagerTensorFromHandle(packed_tensor));
auto* result = PyList_GET_ITEM(flat_result.get(), 0);
Py_INCREF(result);
return tensorflow::PyoOrThrow(result);
});
m.def("Unpack", [](const py::handle& context,
const py::handle& dtensor_handle,
const py::capsule& device_info) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
if (!EagerTensor_CheckExact(dtensor_handle.ptr())) {
throw py::type_error(absl::StrFormat("Expecting a Tensor, got %s",
py::str(dtensor_handle.get_type())));
}
TFE_TensorHandle* tensor_handle = EagerTensor_Handle(dtensor_handle.ptr());
std::vector<TFE_TensorHandle*> unpacked_handles = Unpack(
static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
tensor_handle, device_info, status.get());
if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
return tensorflow::PyoOrThrow(nullptr);
// Convert all TFE_TensorHandles to py EagerTensor and
// return a python list of them.
int num_outputs = unpacked_handles.size();
PyObject* result(PyList_New(num_outputs));
for (int i = 0; i < num_outputs; ++i) {
PyList_SET_ITEM(result, i, EagerTensorFromHandle(unpacked_handles[i]));
}
return tensorflow::PyoOrThrow(result);
});
m.def(
"FetchLayout",
[](const py::handle& context, const py::handle& dtensor_handle,
const py::capsule& device_info) -> py::object {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
if (!EagerTensor_CheckExact(dtensor_handle.ptr())) {
return py::none();
}
TFE_TensorHandle* tensor_handle =
EagerTensor_Handle(dtensor_handle.ptr());
std::string layout_string =
FetchLayout(static_cast<TFE_Context*>(
PyCapsule_GetPointer(context.ptr(), nullptr)),
tensor_handle, device_info, status.get());
if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
return tensorflow::PyoOrThrow(nullptr);
return tensorflow::PyoOrThrow(
PyUnicode_FromString(layout_string.c_str()));
});
m.def("IsDTensor", [](const py::handle& context,
const py::handle& dtensor_handle,
const py::capsule& device_info) {
if (!EagerTensor_CheckExact(dtensor_handle.ptr())) {
return false;
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* tensor_handle = EagerTensor_Handle(dtensor_handle.ptr());
bool is_dtensor = IsDTensor(
static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
tensor_handle, device_info, status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
return is_dtensor;
});
m.def("IsSparseDTensor", [](const py::handle& context,
const py::handle& dtensor_handle,
const py::capsule& device_info) {
if (!EagerTensor_CheckExact(dtensor_handle.ptr())) {
return false;
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* tensor_handle = EagerTensor_Handle(dtensor_handle.ptr());
bool is_sparse = IsSparseDTensor(
static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
tensor_handle, device_info, status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
throw py::error_already_set();
}
return is_sparse;
});
m.def("GetFunctionCacheStats", [](const py::handle& context,
const py::capsule& device_info) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
return GetFunctionCacheStats(
static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
device_info, status.get());
});
m.def("SetIteratorElementLayouts",
[](const py::handle& context, const py::handle& dtensor_handle,
const std::vector<std::string>& element_layouts,
const py::capsule& device_info) {
if (!EagerTensor_CheckExact(dtensor_handle.ptr())) {
throw py::type_error(
absl::StrFormat("Expecting a Tensor, got %s",
py::str(dtensor_handle.get_type())));
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* tensor_handle =
EagerTensor_Handle(dtensor_handle.ptr());
SetIteratorElementLayouts(
static_cast<TFE_Context*>(
PyCapsule_GetPointer(context.ptr(), nullptr)),
tensor_handle, element_layouts, device_info, status.get());
});
py::class_<Mesh>(m, "Mesh")
.def(py::init(&Mesh::CreateMesh))
.def(py::init([](absl::string_view single_device) {
auto mesh = Mesh::GetSingleDeviceMesh(single_device);
if (!mesh.ok()) {
throw py::value_error(std::string(mesh.status().message()));
}
return *mesh;
}),
py::arg("single_device"), "Creates a single device mesh.")
.def(py::init([](const tensorflow::dtensor::MeshProto& proto) {
auto mesh = Mesh::ParseFromProto(proto);
if (!mesh.ok()) {
throw py::value_error(std::string(mesh.status().message()));
}
return *mesh;
}),
py::arg("mesh_proto"), "Returns a Mesh from a MeshProto.")
.def(py::init([](std::string_view mesh_str) {
auto mesh = Mesh::FromString(mesh_str);
if (!mesh.ok()) {
throw py::value_error(std::string(mesh.status().message()));
}
return *mesh;
}),
py::arg("mesh_str"), "Returns a Mesh from a string.")
.def_property_readonly("name", &Mesh::name)
.def_property_readonly("dim_names", &Mesh::MeshDimNames)
.def_property_readonly("size", &Mesh::num_devices)
.def_property_readonly("single_device", &Mesh::single_device)
.def("__contains__", &Mesh::IsMeshDim, py::arg("dim_name"))
.def("__eq__", &Mesh::operator==)
.def("to_string", &Mesh::ToString,
"Returns string representation of Mesh.")
.def("is_single_device", &Mesh::IsSingleDevice,
"Returns True if the mesh represents a non-distributed device.")
.def("contains_dim", &Mesh::IsMeshDim, py::arg("dim_name"),
"Returns True if a Mesh contains the given dimension name.")
.def(
"dim_size",
[](const Mesh& mesh, std::string_view name) {
auto dim_size = mesh.dim_size(name);
if (!dim_size.ok()) {
throw py::value_error(std::string(dim_size.status().message()));
}
return *dim_size;
},
py::arg("dim_name"), "Returns the size of mesh dimension.")
.def("device_type", &Mesh::device_type,
"Returns the device_type of a Mesh.")
.def("host_mesh", &Mesh::host_mesh,
"Returns a host mesh corresponding to this mesh.")
.def("num_local_devices", &Mesh::num_local_devices,
"Returns the number of local devices.")
.def("min_global_device_id", &Mesh::min_global_device_id,
"Returns the minimum global device ID.")
.def("is_remote", &Mesh::is_remote,
"Returns True if a Mesh contains only remote devices.")
.def("local_device_ids", &Mesh::local_device_ids,
"Returns a list of local device IDs.")
.def("local_devices", &Mesh::local_devices,
"Returns a list of local device specs "
"represented as strings.")
.def("global_device_ids", &Mesh::global_device_ids,
"Returns a list of global device IDs.")
.def("global_devices", &Mesh::global_devices,
"Returns a list of global device specs "
"represented as strings.")
.def("shape", &Mesh::dim_sizes, "Returns the shape of the mesh.")
.def("use_xla_spmd", &Mesh::use_xla_spmd,
"Returns True if Mesh will use XLA for SPMD "
"instead of DTensor SPMD.")
.def(
"as_proto",
[](const Mesh& mesh) {
auto mesh_proto = mesh.ToProto();
if (!mesh_proto.ok()) {
throw py::value_error(std::string(mesh_proto.status().message()));
}
return *mesh_proto;
},
"Returns the MeshProto protobuf message.")
.def("device_location", [](const Mesh& mesh, int device_id) {
auto location = mesh.device_location(device_id);
if (!location.ok()) {
throw py::value_error(std::string(location.status().message()));
}
return std::vector<int64_t>(location->begin(), location->end());
});
py::class_<Layout>(m, "Layout")
.def(py::init([](const std::vector<std::string>& sharding_specs,
const Mesh& mesh) {
auto layout = Layout::GetLayout(sharding_specs, mesh);
if (!layout.ok()) {
throw py::value_error(std::string(layout.status().message()));
}
return *layout;
}),
py::arg("sharding_specs"), py::arg("mesh"))
.def(py::init([](const tensorflow::dtensor::LayoutProto& proto) {
auto layout = Layout::FromProto(proto);
if (!layout.ok()) {
throw py::value_error(std::string(layout.status().message()));
}
return *layout;
}),
py::arg("layout_proto"), "Returns a Layout from a LayoutProto.")
.def(py::init([](std::string_view layout_str) {
auto layout = Layout::FromString(layout_str);
if (!layout.ok()) {
throw py::value_error(std::string(layout.status().message()));
}
return *layout;
}),
py::arg("layout_str"), "Returns a Layout from a string.")
.def(py::init(&Layout::ReplicatedOnMesh), py::arg("mesh"),
py::arg("rank"), "Returns a replicated layout.")
.def(py::init(&Layout::BatchShardedOnMesh), py::arg("mesh"),
py::arg("rank"), py::arg("batch_dim"), py::arg("axis"),
"Returns a batch sharded layout.")
.def(py::init([](const Mesh& mesh) {
auto layout = Layout::GetSingleDeviceLayout(mesh);
if (!layout.ok()) {
throw py::value_error(std::string(layout.status().message()));
}
return *layout;
}),
py::arg("mesh"), "Returns a single device layout.")
.def("__eq__", &Layout::operator==)
.def(
"as_proto",
[](const Layout& layout) {
auto layout_proto = layout.ToProto();
if (!layout_proto.ok()) {
throw py::value_error(
std::string(layout_proto.status().message()));
}
return *layout_proto;
},
"Returns the LayoutProto protobuf message.")
.def("to_string", &Layout::ToString)
.def_property_readonly("sharding_specs", &Layout::sharding_spec_strs)
.def_property_readonly("rank", &Layout::rank)
.def_property_readonly("mesh", &Layout::mesh)
.def("is_fully_replicated", &Layout::IsFullyReplicated,
"Returns True if all tensor axes are replicated.")
.def("is_batch_parallel",
[](const Layout& layout) { return layout.IsBatchParallel(); })
.def("is_single_device", &Layout::IsSingleDevice,
"Returns True if the Layout represents a non-distributed device.")
.def(
"num_shards",
[](const Layout& layout, int dim) {
return layout.num_shards_for_dim(dim);
},
py::arg("idx"),
"Returns the number of shards for tensor dimension `idx`.");
}