/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/dtensor/cc/dtensor_device.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"

namespace py = ::pybind11;
using tensorflow::dtensor::AddMesh;
using tensorflow::dtensor::AllocateDTensorDevice;
using tensorflow::dtensor::ClearTPUCoreIDs;
using tensorflow::dtensor::ExperimentalClearDefaultLayout;
using tensorflow::dtensor::ExperimentalClearDefaultMesh;
using tensorflow::dtensor::ExperimentalSetDefaultLayout;
using tensorflow::dtensor::ExperimentalSetDefaultMesh;
using tensorflow::dtensor::FetchLayout;
using tensorflow::dtensor::GetFunctionCacheHitAndMissCount;
using tensorflow::dtensor::IsSparseDTensor;
using tensorflow::dtensor::Pack;
using tensorflow::dtensor::SetSameShapePolicy;
using tensorflow::dtensor::SetTPUCoreIDs;
using tensorflow::dtensor::SparsePack;
using tensorflow::dtensor::TPUCoreIDsToLocations;
using tensorflow::dtensor::TPUCoreLocationsToIDs;
using tensorflow::dtensor::Unpack;

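// Decrements the reference count of `obj` if it is non-null; used below as a
// deleter for std::unique_ptr-managed PyObjects.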
void PyXDecref(PyObject* obj) { Py_XDECREF(obj); }

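// PyCapsule destructor that deletes the TFE_CustomDevice owned by a
// "TFE_CustomDevice" capsule.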
void CallDelete_Device(PyObject* capsule) {
  delete reinterpret_cast<TFE_CustomDevice*>(
      PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
}

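// PyCapsule destructor for a "TFE_CustomDevice_DeviceInfo" capsule: invokes
// the device-specific destructor stored in the capsule's context.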
void CallDelete_DeviceInfo(PyObject* capsule) {
  void (*destructor)(void*) =
      reinterpret_cast<void (*)(void*)>(PyCapsule_GetContext(capsule));
  destructor(PyCapsule_GetPointer(capsule, "TFE_CustomDevice_DeviceInfo"));
}

// Supports 2 cases:
//  i) the input is an EagerTensor.
//  ii) the input is an arbitrary Python list/tuple.
void ConvertToTensor(TFE_Context* ctx, PyObject* input,
                     tensorflow::Safe_PyObjectPtr* output_handle,
                     TF_Status* status) {
  if (EagerTensor_CheckExact(input)) {
    // Input is already an EagerTensor, so increment its reference count, since
    // the caller will use it through output_handle.
    Py_INCREF(input);
    output_handle->reset(input);
    return;
  }
  TFE_TensorHandle* handle =
      tensorflow::ConvertToEagerTensor(ctx, input, tensorflow::DT_INVALID);
  if (handle == nullptr) {
    TF_SetStatus(status, TF_INTERNAL, "Failure converting to eager tensor.");
    return;
  }
  output_handle->reset(EagerTensorFromHandle(handle));
}

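// Python bindings for the DTensor custom eager device. Each binding below is
// a thin wrapper around the corresponding C++ entry point declared in
// tensorflow/dtensor/cc/dtensor_device.h: it unwraps Python capsules/handles
// into raw pointers and converts TF_Status failures into Python exceptions.
//
// Illustrative Python-side usage (a sketch only; the actual call sites live
// in the DTensor Python package, and the device name here is hypothetical):
//
//   device, device_info = _pywrap_dtensor_device.Allocate("<dtensor device>")
//   _pywrap_dtensor_device.AddMesh(device_info, serialized_mesh,
//                                  False, False)  # is_async, is_host_mesh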
PYBIND11_MODULE(_pywrap_dtensor_device, m) {
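  // Allocates a DTensor custom device with the given name. Returns a
  // (device, device_info) pair of PyCapsules whose destructors release the
  // underlying C++ objects.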
  m.def("Allocate", [](const std::string& name) {
    TFE_CustomDevice* device = new TFE_CustomDevice;
    std::unique_ptr<PyObject, decltype(&PyXDecref)> device_capsule(
        PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device),
        PyXDecref);
    void* device_info;
    AllocateDTensorDevice(name, device, &device_info);
    std::unique_ptr<PyObject, decltype(&PyXDecref)> device_info_capsule(
        PyCapsule_New(device_info, "TFE_CustomDevice_DeviceInfo",
                      &CallDelete_DeviceInfo),
        PyXDecref);
    // The PyCapsule destructor needs a pointer to the destructor for
    // DeviceInfo.
    PyCapsule_SetContext(device_info_capsule.get(),
                         reinterpret_cast<void*>(device->delete_device));
    if (PyErr_Occurred()) throw py::error_already_set();
    return pybind11::reinterpret_steal<pybind11::object>(
        PyTuple_Pack(2, device_capsule.get(), device_info_capsule.get()));
  });
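  // Registers a serialized mesh with the DTensor device; raises ValueError if
  // the mesh cannot be added.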
  m.def("AddMesh", [](const py::capsule& device_info,
                      const std::string& serialized_mesh, bool is_async,
                      bool is_host_mesh) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    AddMesh(
        serialized_mesh,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        is_async, is_host_mesh, status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
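  // Sets the default layout, given in serialized form, on the DTensor device;
  // raises ValueError on failure.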
  m.def(
      "ExperimentalSetDefaultLayout",
      [](const py::capsule& device_info, const std::string& serialized_layout) {
        std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
            TF_NewStatus(), TF_DeleteStatus);
        ExperimentalSetDefaultLayout(
            serialized_layout,
            PyCapsule_GetPointer(device_info.ptr(),
                                 "TFE_CustomDevice_DeviceInfo"),
            status.get());
        if (TF_GetCode(status.get()) != TF_OK) {
          PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
          throw py::error_already_set();
        }
      });
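  // Clears any default layout previously set on the DTensor device.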
  m.def("ExperimentalClearDefaultLayout", [](const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    ExperimentalClearDefaultLayout(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
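  // Sets the default mesh, given in serialized form, on the DTensor device;
  // raises ValueError on failure.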
  m.def("ExperimentalSetDefaultMesh", [](const py::capsule& device_info,
                                         const std::string& serialized_mesh) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    ExperimentalSetDefaultMesh(
        serialized_mesh,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
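  // Clears any default mesh previously set on the DTensor device.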
  m.def("ExperimentalClearDefaultMesh", [](const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    ExperimentalClearDefaultMesh(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
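  // Enables or disables the same-shape policy on the DTensor device.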
  m.def("SetSameShapePolicy", [](const py::capsule& device_info, bool enabled) {
    SetSameShapePolicy(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        enabled);
  });
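  // Assigns the given TPU core IDs to the named mesh; raises ValueError on
  // failure.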
  m.def("SetTPUCoreIDs", [](const py::capsule& device_info,
                            const std::string& mesh_name,
                            const std::vector<int>& tpu_core_ids) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    SetTPUCoreIDs(
        mesh_name, tpu_core_ids,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
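  // Clears all TPU core ID assignments from the DTensor device.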
  m.def("ClearTPUCoreIDs", [](const py::capsule& device_info) {
    ClearTPUCoreIDs(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"));
  });
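  // Translates TPU core IDs into core locations using the DTensor device.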
  m.def("TPUCoreIDsToLocations", [](const py::handle& context,
                                    const py::capsule& device_info,
                                    const std::vector<int>& tpu_core_ids) {
    return TPUCoreIDsToLocations(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        tpu_core_ids,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"));
  });
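  // Translates TPU core locations back into core IDs using the DTensor device.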
  m.def("TPUCoreLocationsToIDs",
        [](const py::handle& context, const py::capsule& device_info,
           const std::vector<std::vector<int>>& tpu_core_locations) {
          return TPUCoreLocationsToIDs(
              static_cast<TFE_Context*>(
                  PyCapsule_GetPointer(context.ptr(), nullptr)),
              tpu_core_locations,
              PyCapsule_GetPointer(device_info.ptr(),
                                   "TFE_CustomDevice_DeviceInfo"));
        });
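  // Packs a Python list of component tensors into a single DTensor with the
  // given layout. For sparse packing the list is interpreted as three equal
  // slices: indices, values, and shapes.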
  m.def("Pack", [](const py::handle& context, const py::handle& input_tensors,
                   const std::string& string_layout,
                   const py::capsule& device_info, const bool is_sparse) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    TFE_Context* ctx =
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr));
    // Convert each Python object into a safely owned EagerTensor handle.
    std::vector<tensorflow::Safe_PyObjectPtr> py_eager_tensor_handles;
    Py_ssize_t len = PyList_Size(input_tensors.ptr());
    py_eager_tensor_handles.resize(len);

    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
      ConvertToTensor(ctx, elem, &py_eager_tensor_handles[i], status.get());

      if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
        return tensorflow::PyoOrThrow(nullptr);
    }
    std::vector<TFE_TensorHandle*> input_vector;
    input_vector.resize(len);
    for (int i = 0; i < len; ++i)
      input_vector[i] = EagerTensor_Handle(py_eager_tensor_handles[i].get());
    TFE_TensorHandle* packed_tensor;
    if (is_sparse) {
      auto size = input_vector.size() / 3;
      packed_tensor = SparsePack(
          ctx,
          /*num_inputs=*/input_vector.size() / 3,
          /*indices=*/
          std::vector<TFE_TensorHandle*>(input_vector.begin(),
                                         input_vector.begin() + size)
              .data(),
          /*values=*/
          std::vector<TFE_TensorHandle*>(input_vector.begin() + size,
                                         input_vector.begin() + 2 * size)
              .data(),
          /*shapes=*/
          std::vector<TFE_TensorHandle*>(input_vector.begin() + 2 * size,
                                         input_vector.end())
              .data(),
          string_layout, device_info, status.get());
    } else {
      packed_tensor = Pack(ctx, input_vector.size(), input_vector.data(),
                           string_layout, device_info, status.get());
    }
    if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
      return tensorflow::PyoOrThrow(nullptr);
    // Convert the C++ packed tensor handle into a Python eager tensor object.
    tensorflow::Safe_PyObjectPtr flat_result(PyList_New(1));
    PyList_SET_ITEM(flat_result.get(), 0, EagerTensorFromHandle(packed_tensor));
    auto* result = PyList_GET_ITEM(flat_result.get(), 0);
    Py_INCREF(result);
    return tensorflow::PyoOrThrow(result);
  });
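  // Unpacks a DTensor into its component tensors, returned as a Python list
  // of EagerTensors.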
  m.def("Unpack", [](const py::handle& context,
                     const py::handle& dtensor_handle,
                     const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);

    TFE_TensorHandle* input_handle = EagerTensor_Handle(dtensor_handle.ptr());
    std::vector<TFE_TensorHandle*> unpacked_handles = Unpack(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        input_handle, device_info, status.get());

    if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
      return tensorflow::PyoOrThrow(nullptr);
    // Convert each TFE_TensorHandle to a Python EagerTensor and return them
    // as a Python list.
    int num_outputs = unpacked_handles.size();
    PyObject* result(PyList_New(num_outputs));
    for (int i = 0; i < num_outputs; ++i) {
      PyList_SET_ITEM(result, i, EagerTensorFromHandle(unpacked_handles[i]));
    }
    return tensorflow::PyoOrThrow(result);
  });
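  // Returns the serialized layout of a DTensor as a Python string.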
  m.def(
      "FetchLayout",
      [](const py::handle& context, const py::handle& dtensor_handle,
         const py::capsule& device_info) -> py::object {
        std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
            TF_NewStatus(), TF_DeleteStatus);

        std::string layout_string =
            FetchLayout(static_cast<TFE_Context*>(
                            PyCapsule_GetPointer(context.ptr(), nullptr)),
                        EagerTensor_Handle(dtensor_handle.ptr()), device_info,
                        status.get());
        if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
          return tensorflow::PyoOrThrow(nullptr);
        return tensorflow::PyoOrThrow(
            PyUnicode_FromString(layout_string.c_str()));
      });
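  // Returns whether the given DTensor handle refers to a sparse DTensor.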
  m.def("IsSparseDTensor", [](const py::handle& context,
                              const py::handle& dtensor_handle,
                              const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);

    TFE_TensorHandle* input_handle = EagerTensor_Handle(dtensor_handle.ptr());
    bool is_sparse = IsSparseDTensor(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        input_handle, device_info, status.get());

    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
    return is_sparse;
  });
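  // Returns the DTensor device's function cache hit and miss counts.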
  m.def("GetFunctionCacheHitAndMissCount", [](const py::handle& context,
                                              const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    return GetFunctionCacheHitAndMissCount(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        device_info, status.get());
  });
}