// xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tfe_wrapper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "Python.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "pybind11/chrono.h"
#include "pybind11/complex.h"
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/dlpack.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/get_compiler_ir.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"

namespace py = pybind11;

PYBIND11_MAKE_OPAQUE(TFE_Executor);
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);

PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge3);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge4);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);

PYBIND11_MAKE_OPAQUE(TF_DeviceList);
PYBIND11_MAKE_OPAQUE(TF_Function);
PYBIND11_MAKE_OPAQUE(TF_Buffer);

// Eager helper functions migrated from pywrap_tfe.i.

namespace tensorflow {

// We cannot use Context as an opaque type. SWIG also had
// difficulty passing the pointer around directly. These
// typemaps are migrated over from pywrap_tfe.i. I tried
// using a custom type caster, but we get segfaults periodically.

// TODO(amitpatankar): Move input and output logic of Context into a
// pybind11 custom type caster.

TFE_Context* InputTFE_Context(const py::handle& ctx) {
  return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr));
}

PyObject* OutputTFE_Context(TFE_Context* context) {
  return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule);
}
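
// Illustrative sketch (not part of this module's logic): these capsule helpers
// are what let Python hold a TFE_Context* opaquely. A hypothetical caller that
// already has a valid TFE_ContextOptions* `opts` could round-trip a context
// like so:
//
//   TF_Status* s = TF_NewStatus();
//   TFE_Context* ctx = TFE_NewContext(opts, s);
//   PyObject* capsule = OutputTFE_Context(ctx);     // wrap in a PyCapsule
//   TFE_Context* same = InputTFE_Context(capsule);  // unwrap again; same == ctx
//   Py_DECREF(capsule);  // the capsule destructor runs TFE_DeleteContextCapsule
//   TF_DeleteStatus(s);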

TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
  // Convert a Python string object to TF_Buffer.
  char* c_string;
  Py_ssize_t py_size;
  // PyBytes_AsStringAndSize() does not copy; it returns a pointer into the
  // input bytes object.
  if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    throw py::error_already_set();
  }
  return TF_NewBufferFromString(static_cast<void*>(c_string),
                                static_cast<size_t>(py_size));
}

// These functions are typemaps from the Python side. I did not use
// a custom type caster since the logic is slightly harder to follow. This
// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
TFE_InputTensorHandles InputTFE_InputTensorHandles(
    const py::handle& input_tensors) {
  TFE_InputTensorHandles input_tensor_handles;
  if (input_tensors.ptr() != Py_None) {
    if (!PyList_Check(input_tensors.ptr())) {
      tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
    }
    Py_ssize_t len = PyList_Size(input_tensors.ptr());
    input_tensor_handles.resize(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
      if (!elem) {
        tensorflow::ThrowTypeError("Input Tensor does not exist.");
      }
      if (EagerTensor_CheckExact(elem)) {
        (input_tensor_handles)[i] = EagerTensor_Handle(elem);
      } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
        // Use equivalent of object.__getattribute__ to get the underlying
        // tf wrapped EagerTensor (if there is one).
        tensorflow::Safe_PyObjectPtr tf_should_use_attr(
#if PY_MAJOR_VERSION < 3
            PyString_InternFromString("_tf_should_use_wrapped_value")
#else
            PyUnicode_InternFromString("_tf_should_use_wrapped_value")
#endif
        );
        tensorflow::Safe_PyObjectPtr value_attr(
            PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
        if (value_attr) {
          // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
          (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
        } else {
          // This is a subclass of EagerTensor that we don't support.
          PyErr_Clear();
          tensorflow::ThrowTypeError(
              tensorflow::strings::StrCat(
                  "Saw an object that is an instance of a strict subclass of "
                  "EagerTensor, which is not supported.  Item ",
                  i, " is type: ", elem->ob_type->tp_name)
                  .c_str());
        }
      } else if (tensorflow::swig::IsTensor(elem)) {
        // If it isn't an EagerTensor, but is still a Tensor, it must be a graph
        // tensor.
        tensorflow::Safe_PyObjectPtr py_tensor_repr(PyObject_Repr(elem));
        std::string tensor_repr =
            py_tensor_repr ? TFE_GetPythonString(py_tensor_repr.get())
                           : "<unknown>";
        tensorflow::Safe_PyObjectPtr py_op(PyObject_GetAttrString(elem, "op"));
        tensorflow::Safe_PyObjectPtr py_defined_graph(
            PyObject_GetAttrString(py_op.get(), "graph"));
        tensorflow::Safe_PyObjectPtr py_defined_graph_str(
            PyObject_Str(py_defined_graph.get()));
        std::string defined_graph_str =
            py_defined_graph_str
                ? TFE_GetPythonString(py_defined_graph_str.get())
                : "<unknown>";
        tensorflow::Safe_PyObjectPtr c_op(
            PyObject_GetAttrString(py_op.get(), "_c_op"));
        auto& node = py::cast<TF_Operation*>(c_op.get())->node;
        auto node_name_str = node.name();
        std::string frame_str, traceback_str;
        if (auto stack_trace = node.GetStackTrace()) {
          auto frame = stack_trace->LastUserFrame();
          frame_str =
              absl::StrFormat("File \"%s\", line %d, in %s", frame.file_name,
                              frame.line_number, frame.function_name);
          auto stack_trace_list =
              absl::StrSplit(stack_trace->ToString({true}), '\n');
          traceback_str = absl::StrJoin(
              stack_trace_list, "", [&](std::string* out, const auto line) {
                absl::StrAppend(out, "    ", line, "\n");
              });
        } else {
          frame_str = "<unknown>";
          traceback_str = "<unknown>\n";
        }
        // Keep in sync with func_graph.py.
        // TODO(b/200991648): Unify those two paths.
        tensorflow::ThrowTypeError(
            tensorflow::strings::StrCat(
                tensor_repr,
                " is out of scope and cannot be used here. "
                "Use return values, explicit Python locals or TensorFlow "
                "collections to access it.\n"
                "Please see https://www.tensorflow.org/guide/"
                "function#all_outputs_of_a_tffunction_must_be_return_values "
                "for more information.\n\n",
                tensor_repr, " was defined here:\n", traceback_str,
                "\nThe tensor ", tensor_repr,
                " cannot be accessed from here, because it was "
                "defined in ",
                defined_graph_str, ", which is out of scope.")
                .c_str());
      } else {
        tensorflow::ThrowTypeError(
            tensorflow::strings::StrCat(
                "provided list of inputs contains objects other "
                "than 'EagerTensor'. Item ",
                i, " is type: ", elem->ob_type->tp_name)
                .c_str());
      }
    }
  }
  return input_tensor_handles;
}
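
// A minimal sketch of the contract above (illustrative values): given a Python
// list of EagerTensors, the converter yields a small vector of
// TFE_TensorHandle* aligned with the list; anything that is not an EagerTensor
// (or a TFShouldUse wrapper around one) raises TypeError instead.
//
//   // `inputs` wraps e.g. [tf.constant(1.0), tf.constant(2.0)] on the Python side.
//   TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
//   // handles.size() == 2; handles[i] is the handle backing list item i.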

// These functions are typemaps from the Python side. I did not use
// a custom type caster since the logic is slightly harder to follow. This
// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
// This function actually takes a number rather than an output Tensor holder.
TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
    const py::handle& num_outputs) {
  TFE_OutputTensorHandles output_tensor_handles;
#if PY_MAJOR_VERSION < 3
  if (!PyInt_Check(num_outputs.ptr())) {
#else
  if (!PyLong_Check(num_outputs.ptr())) {
#endif
    PyErr_SetString(PyExc_TypeError,
                    "expected an integer value (size of the number of "
                    "outputs of the operation)");
    throw py::error_already_set();
  }
#if PY_MAJOR_VERSION < 3
  long sz = PyInt_AsLong(num_outputs.ptr());  // NOLINT
#else
  long sz = PyLong_AsLong(num_outputs.ptr());  // NOLINT
#endif
  // PyLong_AsLong might throw an error if an overflow occurs.
  if (PyErr_Occurred()) {
    PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
                                          "Number of outputs is too big: ", sz)
                                          .c_str());
    throw py::error_already_set();
  }
  // We can't handle more than int32 sizes for number of outputs.
  if (static_cast<long>(static_cast<int32_t>(sz)) != sz) {  // NOLINT
    PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
                                          "Number of outputs is too big: ", sz)
                                          .c_str());
    throw py::error_already_set();
  }
  if (sz > 0) {
#if PY_MAJOR_VERSION < 3
    output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
#else
    output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr);
#endif
  }
  return output_tensor_handles;
}
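
// For reference (sketch): unlike the input converter, this one takes a Python
// int, so a `num_outputs` wrapping the value 2 produces a vector of two null
// placeholders that the execute call later fills in:
//
//   TFE_OutputTensorHandles outs = InputTFE_OutputTensorHandles(num_outputs);
//   // outs.size() == 2, each element initially nullptr.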

tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) {
  auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
      tensorflow::InputTFE_Context(ctx));

  tensorflow::DeviceNameUtils::ParsedName input_device_name;
  if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name,
                                                         &input_device_name)) {
    tensorflow::ThrowValueError(
        absl::StrFormat("Failed parsing device name: '%s'. Note a valid device "
                        "string should at least contain a device type and a "
                        "device index, like \"GPU:0\".",
                        device_name)
            .c_str());
  }

  std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices();

  tensorflow::Device* matched_device = nullptr;
  for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
    tensorflow::Device* device = devices[device_idx];

    if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
            input_device_name, device->parsed_name())) {
      if (matched_device != nullptr) {
        tensorflow::ThrowValueError(
            absl::StrFormat("Multiple devices match the provided string "
                            "'%s': '%s' and '%s'.",
                            device_name, matched_device->name(), device->name())
                .c_str());
      }
      matched_device = device;
    }
  }

  if (matched_device == nullptr) {
    tensorflow::ThrowValueError(
        absl::StrFormat("No matching devices found for '%s'", device_name)
            .c_str());
  }

  return matched_device;
}
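
// Worked example of the matching rule (illustrative device names): a partial
// name such as "GPU:0" parses into a ParsedName with only type and id set, so
// it is compatible with the fully qualified local device
// "/job:localhost/replica:0/task:0/device:GPU:0". A string compatible with
// several devices (e.g. just "GPU" on a multi-GPU host) raises ValueError, as
// does one that matches nothing.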

// Packs multiple `EagerTensor`s of the same dtype and shape into one
// `EagerTensor`.
py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
                                           const py::handle& tensors) {
  TFE_Context* ctx = tensorflow::InputTFE_Context(context);
  TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
  tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
  int size = handles.size();
  TFE_TensorHandle* packed_handle =
      TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  PyObject* packed_tensor =
      EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
  return tensorflow::PyoOrThrow(packed_tensor);
}
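
// Hypothetical Python-side use of the binding registered below as
// "TFE_Py_PackEagerTensors" (names illustrative); all inputs must share dtype
// and shape:
//
//   # packed = pywrap_tfe.TFE_Py_PackEagerTensors(ctx_handle, [t0, t1])
//   # `packed` is a single packed EagerTensor wrapping t0 and t1.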

// This function was created from fusing the typemap logic in platform/base.i.
py::object TFE_Py_ExecuteCancelable_wrapper(
    const py::handle& context, const char* device_name, const char* op_name,
    const py::handle& inputs, const py::handle& attrs,
    tensorflow::CancellationManager* cancellation_manager,
    const py::handle& num_outputs) {
  TFE_Context* ctx = tensorflow::InputTFE_Context(context);
  TFE_InputTensorHandles input_tensor_handles =
      InputTFE_InputTensorHandles(inputs);
  TFE_OutputTensorHandles output_tensor_handles =
      InputTFE_OutputTensorHandles(num_outputs);
  tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
  TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
                           attrs.ptr(), tensorflow::wrap(cancellation_manager),
                           &output_tensor_handles, status.get());

  int output_len = output_tensor_handles.size();
  PyObject* output_list = PyList_New(output_len);
  for (int i = 0; i < output_len; ++i) {
    PyObject* output;
    output = EagerTensorFromHandle(output_tensor_handles.at(i));
    PyList_SetItem(output_list, i, output);
  }
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  return tensorflow::PyoOrThrow(output_list);
}

static py::object TF_ListPhysicalDevices() {
  std::vector<string> devices;
  tensorflow::Status s =
      tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
  MaybeRaiseRegisteredFromStatus(s);
  PyObject* result = PyList_New(devices.size());
  int i = 0;
  for (auto& dev : devices) {
    PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
    PyList_SetItem(result, i, dev_obj);
    ++i;
  }
  return tensorflow::PyoOrThrow(result);
}

static py::object TF_ListPluggablePhysicalDevices() {
  std::vector<string> devices;
  tensorflow::Status s =
      tensorflow::DeviceFactory::ListPluggablePhysicalDevices(&devices);
  MaybeRaiseRegisteredFromStatus(s);
  Safe_PyObjectPtr result(PyList_New(devices.size()));
  int i = 0;
  for (auto& dev : devices) {
    PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
    PyList_SetItem(result.get(), i, dev_obj);
    ++i;
  }
  return tensorflow::PyoOrThrow(result.release());
}

static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
  tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
  std::unordered_map<string, string> device_details;
  tensorflow::Status s =
      tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
  tensorflow::Set_TF_Status_from_Status(status.get(), s);
  MaybeRaiseRegisteredFromTFStatus(status.get());
  return device_details;
}

static py::object TFE_ClearScalarCache() {
  tensorflow::TFE_TensorHandleCache::Get()->Clear();
  return py::none();
}

// Returns compiler IR for a given function.
static py::bytes TFE_GetCompilerIr(py::handle& ctx,
                                   const char* concrete_function_name,
                                   const char* stage, const char* device_name,
                                   py::handle& inputs) {
  EagerContext* context = ContextFromInterface(
      reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));

  std::string s_stage(stage);
  IrExportStage selected_stage = [&] {
    if (s_stage == "hlo") {
      return IrExportStage::HLO;
    } else if (s_stage == "hlo_no_metadata") {
      return IrExportStage::HLO_NO_METADATA;
    } else if (s_stage == "hlo_serialized") {
      return IrExportStage::HLO_SERIALIZED;
    } else if (s_stage == "optimized_hlo") {
      return IrExportStage::OPTIMIZED_HLO;
    } else if (s_stage == "optimized_hlo_serialized") {
      return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
    } else if (s_stage == "optimized_hlo_proto_serialized") {
      return IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED;
    } else if (s_stage == "optimized_hlo_dot") {
      return IrExportStage::OPTIMIZED_HLO_DOT;
    } else {
      ThrowValueError(
          absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
                          "'hlo', 'hlo_serialized', 'optimized_hlo', "
                          "'optimized_hlo_serialized', 'optimized_hlo_dot'",
                          s_stage)
              .c_str());
    }
  }();

  TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);

  std::vector<const TensorHandle*> input_handles;
  for (TFE_TensorHandle* tensor_handle : handles) {
    AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
    input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
  }

  DeviceNameUtils::ParsedName input_device_name;
  if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
    ThrowValueError(
        absl::StrFormat("Failed parsing device name: '%s'", device_name)
            .c_str());
  }

  std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
  auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
    return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
                                                  d->parsed_name());
  });
  if (selected_device == devices.end()) {
    ThrowValueError(
        absl::StrFormat("No matching device found for '%s'", device_name)
            .c_str());
  }

  xla::StatusOr<std::string> hlo_str =
      GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
                    *selected_device, context, input_handles);

  if (!hlo_str.ok()) {
    ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
                                    hlo_str.status().error_message())
                        .c_str());
  }
  return py::bytes(*hlo_str);
}

}  // namespace tensorflow

namespace {

// Wrapper around the EagerContextThreadLocalData struct (defined in
// pywrap_tfe.h), so it can be accessed from Python.
//
// For PyObject* fields, the get_*() methods return a new reference; and the
// set_*() methods create a new reference (i.e., they do not steal a reference).
class EagerContextThreadLocalDataWrapper {
 public:
  explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
                                              py::handle is_eager,
                                              py::handle device_spec)
      : py_eager_context_(py_eager_context.ptr()) {
    tensorflow::MakeEagerContextThreadLocalData(
        py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
  }

  ~EagerContextThreadLocalDataWrapper() {
    tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
  }

  bool get_is_eager() const { return GetData()->is_eager; }
  void set_is_eager(bool v) { GetData()->is_eager = v; }

  bool get_invoking_op_callbacks() const {
    return GetData()->invoking_op_callbacks;
  }
  void set_invoking_op_callbacks(bool v) {
    GetData()->invoking_op_callbacks = v;
  }

  py::object get_device_name() const {
    return GetPyObject(&GetData()->device_name);
  }
  void set_device_name(py::handle v) {
    SetPyObject(v, &GetData()->device_name);
  }

  py::object get_scope_name() const {
    return GetPyObject(&GetData()->scope_name);
  }
  void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }

  py::object get_device_spec() const {
    return GetPyObject(&GetData()->device_spec);
  }
  void set_device_spec(py::handle v) {
    SetPyObject(v, &GetData()->device_spec);
  }

  py::object get_function_call_options() const {
    return GetPyObject(&GetData()->function_call_options);
  }
  void set_function_call_options(py::handle v) {
    SetPyObject(v, &GetData()->function_call_options);
  }

  py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
  void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }

  py::object get_op_callbacks() const {
    return GetPyObject(&GetData()->op_callbacks);
  }
  void set_op_callbacks(py::handle v) {
    SetPyObject(v, &GetData()->op_callbacks);
  }

 private:
  tensorflow::EagerContextThreadLocalData* GetData() const {
    auto* result =
        tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
    if (!result) {
      throw py::error_already_set();
    }
    return result;
  }

  py::object GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
    return pybind11::reinterpret_borrow<py::object>(obj->get());
  }

  void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
    Py_INCREF(value.ptr());
    ptr->reset(value.ptr());
  }

  PyObject* py_eager_context_;  // not owned (borrowed reference).
};
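
// A minimal sketch (illustrative; not necessarily the exact binding used by
// this module) of how a wrapper like this is typically surfaced to Python via
// pybind11 properties:
//
//   py::class_<EagerContextThreadLocalDataWrapper>(m, "EagerContextThreadLocalData")
//       .def(py::init<py::handle, py::handle, py::handle>())
//       .def_property("is_eager",
//                     &EagerContextThreadLocalDataWrapper::get_is_eager,
//                     &EagerContextThreadLocalDataWrapper::set_is_eager)
//       .def_property("device_name",
//                     &EagerContextThreadLocalDataWrapper::get_device_name,
//                     &EagerContextThreadLocalDataWrapper::set_device_name);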

}  // namespace

// py::return_value_policy::reference behaves as described in the pybind11
// documentation on return value policies:
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
// It means that C++ retains ownership of the returned object; we only use it
// for functions that return opaque types.
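//
// For example (sketch of the policy as applied below): a binding such as
//
//   m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
//         py::return_value_policy::reference);
//
// hands the raw TFE_ContextOptions* to Python without registering a deleter,
// so the object stays alive until the explicit TFE_DeleteContextOptions
// binding is called.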

PYBIND11_MODULE(_pywrap_tfe, m) {
  py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
  py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
                                                          "TFE_ContextOptions");
  py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
      m, "TFE_MonitoringCounter0");
  py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
      m, "TFE_MonitoringCounter1");
  py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
      m, "TFE_MonitoringCounter2");
  py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
      m, "TFE_MonitoringStringGauge0");
  py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
      m, "TFE_MonitoringStringGauge1");
  py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
      m, "TFE_MonitoringStringGauge2");
  py::class_<TFE_MonitoringStringGauge3> TFE_MonitoringStringGauge3_class(
      m, "TFE_MonitoringStringGauge3");
  py::class_<TFE_MonitoringStringGauge4> TFE_MonitoringStringGauge4_class(
      m, "TFE_MonitoringStringGauge4");
  py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
      m, "TFE_MonitoringIntGauge0");
  py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
      m, "TFE_MonitoringIntGauge1");
  py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
      m, "TFE_MonitoringIntGauge2");
  py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
      m, "TFE_MonitoringBoolGauge0");
  py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
      m, "TFE_MonitoringBoolGauge1");
  py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
      m, "TFE_MonitoringBoolGauge2");
  py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
      m, "TFE_MonitoringCounterCell");
  py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
      m, "TFE_MonitoringIntGaugeCell");
  py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
      m, "TFE_MonitoringStringGaugeCell");
  py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
      m, "TFE_MonitoringBoolGaugeCell");
  py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
      m, "TFE_MonitoringSamplerCell");
  py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
      m, "TFE_MonitoringBuckets");
  py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
      m, "TFE_MonitoringSampler0");
  py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
      m, "TFE_MonitoringSampler1");
  py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
      m, "TFE_MonitoringSampler2");
  py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
      m, "TFE_CancellationManager");

  py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
  py::class_<TF_Function> TF_Function_class(m, "TF_Function");

  m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
    return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
  });
  m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
    return tensorflow::PyoOrThrow(
        TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
  });

  m.def("TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
    tensorflow::Device* matched_device =
        tensorflow::GetMatchedDevice(ctx, device_name);

    tensorflow::AllocatorAttributes attrs;
    tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);

    if (absl::optional<tensorflow::AllocatorStats> stats =
            allocator->GetStats()) {
      return std::map<std::string, int64_t>{{"current", stats->bytes_in_use},
                                            {"peak", stats->peak_bytes_in_use}};
    }

    tensorflow::ThrowValueError(
        absl::StrFormat("Allocator stats not available for device '%s'",
                        device_name)
            .c_str());
  });

  m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
    tensorflow::Device* matched_device =
        tensorflow::GetMatchedDevice(ctx, device_name);

    tensorflow::AllocatorAttributes attrs;
    tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);

    if (!allocator->ClearStats()) {
      tensorflow::ThrowValueError(
          absl::StrFormat("Cannot reset memory stats for device '%s'",
                          device_name)
              .c_str());
    }
  });

  // XLA Eager Logic
  m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
  m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
  m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
  m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
  m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
  m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
  m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);

  // MLIR Logic
  m.def("TF_IsMlirBridgeEnabled", [] {
    // Since python protobuf enums are integers, cast to an integer before
    // returning the enum to python.
    return static_cast<int32_t>(
        tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
  });
  m.def("TF_EnableMlirBridge", [](bool enabled) {
    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
        enabled
            ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
            : tensorflow::ConfigProto::Experimental::
                  MLIR_BRIDGE_ROLLOUT_DISABLED;
  });
  m.def("TF_EnableXlaDevices", [] {
    tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
  });
  m.def("TF_ResetJitCompilerFlags",
        [] { tensorflow::ResetJitCompilerFlags(); });

  // TFE_Context Logic
  m.def(
      "TFE_NewContext",
      [](const TFE_ContextOptions* opts) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        TFE_Context* context = TFE_NewContext(opts, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
      },
      py::return_value_policy::reference);
  m.def("TFE_DeleteContext", [](py::handle& o) {
    TFE_DeleteContext(tensorflow::InputTFE_Context(o));
  });
  m.def(
      "TFE_ContextListDevices",
      [](py::handle& o) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
                                             status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);
  m.def(
      "TFE_SetLogicalCpuDevices",
      [](py::handle& ctx, int num_cpus, const char* prefix) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        TFE_SetLogicalCpuDevices(tensorflow::InputTFE_Context(ctx), num_cpus,
                                 prefix, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      },
      py::return_value_policy::reference);
  m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
    TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
  });
  m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
                           status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });
  m.def("TFE_ContextAddFunctionDef",
        [](py::handle& ctx, const char* serialized_function_def, size_t size) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
                                    serialized_function_def, size,
                                    status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });
  m.def("TFE_ContextGetFunctionDef",
        [](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
                                    function_name, &buf, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });
  m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
                              status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });
  m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    auto output =
        TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    return output;
  });
  m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
    return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
        ->ListFunctionNames();
  });
  m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
    TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
  });
  m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
    TFE_ContextDisableRunMetadata(tensorflow::InputTFE_Context(ctx));
  });
  m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
    TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
  });
  m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
    TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
  });
  m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf,
                                 status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });
  m.def("TFE_ContextClearCaches", [](py::handle& o) {
    TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
  });
  m.def("TFE_GetContextId", [](py::handle& ctx) {
    return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
  });
  m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
    return TFE_ContextGetDevicePlacementPolicy(
        tensorflow::InputTFE_Context(ctx));
  });
  m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
        [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
          TFE_ContextSetThreadLocalDevicePlacementPolicy(
              tensorflow::InputTFE_Context(ctx), policy);
        });
  m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
                                      py::bytes proto) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    tensorflow::Safe_TF_BufferPtr buf =
        tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
    TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
                            buf.get()->data, buf.get()->length, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });
  m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
                                         py::bytes proto) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    tensorflow::Safe_TF_BufferPtr buf =
        tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
    Py_BEGIN_ALLOW_THREADS;
    TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx),
                               keep_alive_secs, buf.get()->data,
                               buf.get()->length, status.get());
    Py_END_ALLOW_THREADS;
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });
  m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
                                        worker_name, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    return output;
  });
  m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
    Py_BEGIN_ALLOW_THREADS;
    TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
    Py_END_ALLOW_THREADS;
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });
  m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
    Py_BEGIN_ALLOW_THREADS;
    TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
    Py_END_ALLOW_THREADS;
    // NOTE: different from TFE_ContextSyncExecutors that raises potential
    // errors, deliberately ignore executor statuses in cleanup.
  });
  m.def(
      "TFE_InsertConfigKeyValue",
      [](py::handle& ctx, const char* config_key, const char* config_value) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        Py_BEGIN_ALLOW_THREADS;
        TFE_InsertConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
                                 config_value, status.get());
        Py_END_ALLOW_THREADS;
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      },
      py::return_value_policy::reference);
  m.def(
      "TFE_GetConfigKeyValue",
      [](py::handle& ctx, const char* config_key, TF_Buffer& config_value) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        Py_BEGIN_ALLOW_THREADS;
        TFE_GetConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
                              &config_value, status.get());
        Py_END_ALLOW_THREADS;
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      },
      py::return_value_policy::reference);
  m.def(
      "TFE_DeleteConfigKeyValue",
      [](py::handle& ctx, const char* config_key) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        Py_BEGIN_ALLOW_THREADS;
        TFE_DeleteConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
                                 status.get());
        Py_END_ALLOW_THREADS;
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      },
      py::return_value_policy::reference);
  m.def(
      "TFE_ReportErrorToCluster",
      [](py::handle& ctx, int error_code, const char* error_message) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        TFE_ReportErrorToCluster(tensorflow::InputTFE_Context(ctx), error_code,
                                 error_message, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
      },
      py::return_value_policy::reference);
  m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
                                      status.get());
  });
  m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
                                     status.get());
  });
925   m.def("TFE_ContextSetRunEagerOpAsFunction", [](py::handle& ctx, bool enable) {
926     tensorflow::Safe_TF_StatusPtr status =
927         tensorflow::make_safe(TF_NewStatus());
928     TFE_ContextSetRunEagerOpAsFunction(tensorflow::InputTFE_Context(ctx),
929                                        enable, status.get());
930   });
931   m.def("TFE_ContextSetJitCompileRewrite", [](py::handle& ctx, bool enable) {
932     tensorflow::Safe_TF_StatusPtr status =
933         tensorflow::make_safe(TF_NewStatus());
934     TFE_ContextSetJitCompileRewrite(tensorflow::InputTFE_Context(ctx), enable,
935                                     status.get());
936   });
937 
938   // TFE_Executor logic
939   m.def(
940       "TFE_NewExecutor",
941       [](const bool is_async, const bool enable_streaming_enqueue) {
942         TFE_Executor* exc = TFE_NewExecutor(is_async, enable_streaming_enqueue);
943         return exc;
944       },
945       py::return_value_policy::reference);
946   m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
947   m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
948   m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
949     tensorflow::Safe_TF_StatusPtr status =
950         tensorflow::make_safe(TF_NewStatus());
951     // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
952     Py_BEGIN_ALLOW_THREADS;
953     TFE_ExecutorWaitForAllPendingNodes(&exc, status.get());
954     Py_END_ALLOW_THREADS;
955     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
956   });
957   m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
958   m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx,
959                                               TFE_Executor& exc) {
960     TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc);
961   });
962   m.def(
963       "TFE_ContextGetExecutorForThread",
964       [](py::handle& o) {
965         return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o));
966       },
967       py::return_value_policy::reference);
968 
969   m.def("TFE_OpNameGetAttrType",
970         [](py::handle& ctx, const char* op_or_function_name,
971            const char* attr_name) {
972           int temp = 0;
973           unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp);
974           tensorflow::Safe_TF_StatusPtr status =
975               tensorflow::make_safe(TF_NewStatus());
976           auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
977                                               op_or_function_name, attr_name,
978                                               is_list, status.get());
979           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
980 #if PY_MAJOR_VERSION < 3
981           PyObject* output_pyo = PyInt_FromLong(output);
982 #else
983           PyObject* output_pyo = PyLong_FromLong(output);
984 #endif
985           if (*is_list == 1) {
986             PyObject* list = PyList_New(1);
987             PyList_SetItem(list, 0, output_pyo);
988             return tensorflow::PyoOrThrow(list);
989           }
990           return tensorflow::PyoOrThrow(output_pyo);
991         });
992   m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
993     return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
994   });
995   m.def("TFE_Py_PackEagerTensors",
996         [](const py::handle& context, const py::handle& handles) {
997           return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
998         });
999   m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
1000   m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
1001     return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
1002   });
1003   m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
1004     return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr()));
1005   });
1006   m.def("TFE_Py_Execute",
1007         [](const py::handle& context, const char* device_name,
1008            const char* op_name, const py::handle& inputs,
1009            const py::handle& attrs, const py::handle& num_outputs) {
1010           return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
1011               context, device_name, op_name, inputs, attrs.ptr(), nullptr,
1012               num_outputs);
1013         });
1014   m.def(
1015       "TFE_Py_ExecuteCancelable",
1016       [](const py::handle& context, const char* device_name,
1017          const char* op_name, const py::handle& inputs, const py::handle& attrs,
1018          tensorflow::CancellationManager& cancellation_manager,
1019          const py::handle& num_outputs) {
1020         return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
1021             context, device_name, op_name, inputs, attrs.ptr(),
1022             &cancellation_manager, num_outputs);
1023       });
1024   m.def("TFE_Py_FastPathExecute", [](const py::args args) {
1025     // TFE_Py_FastPathExecute requires error checking prior to returning.
1026     return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr()));
1027   });
1028   m.def("TFE_Py_RecordGradient",
1029         [](const py::handle& op_name, const py::handle& inputs,
1030            const py::handle& attrs, const py::handle& results,
1031            const py::handle& forward_pass_name_scope) {
1032           return tensorflow::PyoOrThrow(TFE_Py_RecordGradient(
1033               op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
1034               forward_pass_name_scope.ptr()));
1035         });
1036   m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); });
1037 
1038   // TFE_Py_Tape Logic
1039   m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
1040                                 const py::handle& watch_accessed_variables) {
1041     return tensorflow::PyoOrThrow(
1042         TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr()));
1043   });
1044   m.def("TFE_Py_TapeSetAdd",
1045         [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); });
1046   m.def("TFE_Py_TapeSetRemove",
1047         [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); });
1048   m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
1049   m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
1050   m.def("TFE_Py_TapeSetIsStopped",
1051         []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); });
1052   m.def("TFE_Py_TapeSetIsEmpty",
1053         []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); });
1054   m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
1055     return tensorflow::PyoOrThrow(
1056         TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr()));
1057   });
1058   m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
1059     return tensorflow::PyoOrThrow(
1060         TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr()));
1061   });
1062   m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
1063   m.def("TFE_Py_TapeSetRecordOperation",
1064         [](const py::handle& op_type, const py::handle& output_tensors,
1065            const py::handle& input_tensors, const py::handle& backward_function,
1066            const py::handle& forward_function) {
1067           return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation(
1068               op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1069               backward_function.ptr(), forward_function.ptr()));
1070         });
1071   m.def(
1072       "TFE_Py_TapeSetRecordOperationBackprop",
1073       [](const py::handle& op_type, const py::handle& output_tensors,
1074          const py::handle& input_tensors, const py::handle& backward_function) {
1075         return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop(
1076             op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1077             backward_function.ptr()));
1078       });
1079   m.def(
1080       "TFE_Py_TapeSetRecordOperationForwardprop",
1081       [](const py::handle& op_type, const py::handle& output_tensors,
1082          const py::handle& input_tensors, const py::handle& backward_function,
1083          const py::handle& forwardprop_output_indices) {
1084         return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop(
1085             op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1086             backward_function.ptr(), forwardprop_output_indices.ptr()));
1087       });
1088   m.def("TFE_Py_TapeGradient",
1089         [](const py::handle& tape, const py::handle& target,
1090            const py::handle& sources, const py::handle& output_gradients,
1091            const py::handle& sources_raw,
1092            const py::handle& unconnected_gradients) {
1093           tensorflow::Safe_TF_StatusPtr status =
1094               tensorflow::make_safe(TF_NewStatus());
1095           PyObject* output = TFE_Py_TapeGradient(
1096               tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
1097               sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
1098           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1099           return tensorflow::PyoOrThrow(output);
1100         });
1101 
1102   m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
1103     TFE_Py_TapeVariableAccessed(variable.ptr());
1104   });
1105   m.def("TFE_Py_TapeWatch",
1106         [](const py::handle& tape, const py::handle& tensor) {
1107           TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
1108         });
1109   m.def("TFE_Py_TapeWatchVariable",
1110         [](const py::handle& tape, const py::handle& variable) {
1111           TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
1112         });
1113   m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
1114     return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
1115   });
1116 
1117   // TFE_Py_VariableWatcher logic.
1118   m.def("TFE_Py_VariableWatcherNew",
1119         []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
1120   m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
1121     TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
1122   });
1123   m.def("TFE_Py_VariableWatcherVariableAccessed",
1124         [](const py::handle& variable) {
1125           TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
1126         });
1127   m.def("TFE_Py_VariableWatcherWatchedVariables",
1128         [](const py::handle& variable_watcher) {
1129           return tensorflow::PyoOrThrow(
1130               TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
1131         });
1132 
1133   // TFE_Py_ForwardAccumulator logic.
1134   m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
1135     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
1136   });
1137 
1138   m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
1139     return tensorflow::PyoOrThrow(
1140         TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr()));
1141   });
1142   m.def("TFE_Py_ForwardAccumulatorSetRemove",
1143         [](const py::handle& accumulator) {
1144           TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
1145         });
1146 
1147   m.def("TFE_Py_ForwardAccumulatorWatch",
1148         [](const py::handle& accumulator, const py::handle& tensor,
1149            const py::handle& tangent) {
1150           TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
1151                                          tangent.ptr());
1152         });
1153   m.def("TFE_Py_ForwardAccumulatorJVP",
1154         [](const py::handle& accumulator, const py::handle& tensor) {
1155           return tensorflow::PyoOrThrow(
1156               TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr()));
1157         });
1158   m.def("TFE_Py_ForwardAccumulatorPushState", []() {
1159     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState());
1160   });
1161   m.def("TFE_Py_ForwardAccumulatorPopState", []() {
1162     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState());
1163   });
1164   m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
1165     return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr()));
1166   });
1167 
1168   // TFE_ContextOptions Logic
1169   m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
1170         py::return_value_policy::reference);
1171   m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
1172                                           py::bytes proto) {
1173     tensorflow::Safe_TF_StatusPtr status =
1174         tensorflow::make_safe(TF_NewStatus());
1175     tensorflow::Safe_TF_BufferPtr buf =
1176         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1177     TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length,
1178                                 status.get());
1179     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1180   });
1181   m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
1182         &TFE_ContextOptionsSetDevicePlacementPolicy);
1183   m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
1184   m.def("TFE_ContextOptionsSetTfrtDistributedRuntime",
1185         &TFE_ContextOptionsSetTfrtDistributedRuntime);
1186   // Experimental feature, intentionally not exposed as a C API yet.
1187   m.def("TFE_ContextOptionsSetRunEagerOpAsFunction",
1188         [](TFE_ContextOptions* options, bool run_eager_op_as_function) {
1189           options->run_eager_op_as_function = run_eager_op_as_function;
1190         });
1191   m.def("TFE_ContextOptionsSetJitCompileRewrite",
1192         [](TFE_ContextOptions* options, bool jit_compile_rewrite) {
1193           options->jit_compile_rewrite = jit_compile_rewrite;
1194         });
1195   m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
1196   m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
1197         py::return_value_policy::reference);
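  // Illustrative sketch (not executed here) of how the TFE_ContextOptions C
  // API wrapped above is typically driven from C++: options are built, handed
  // to TFE_NewContext, and then freed. Variable names are placeholders and
  // status checks are omitted for brevity.
  //   TF_Status* status = TF_NewStatus();
  //   TFE_ContextOptions* opts = TFE_NewContextOptions();
  //   TFE_ContextOptionsSetAsync(opts, 0);
  //   TFE_ContextOptionsSetDevicePlacementPolicy(opts,
  //                                              TFE_DEVICE_PLACEMENT_SILENT);
  //   TFE_Context* ctx = TFE_NewContext(opts, status);
  //   TFE_DeleteContextOptions(opts);  // safe once the context exists
  //   ...
  //   TFE_DeleteContext(ctx);
  //   TF_DeleteStatus(status);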
1198 
1199   // TFE_Py_TensorShape Logic
1200   m.def("TFE_Py_TensorShapeSlice",
1201         [](const py::handle& tensors, int slice_dim) {
1202           return tensorflow::PyoOrThrow(
1203               TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim));
1204         });
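  // Note: the binding below accepts a `slice_dim` argument that is not
  // forwarded; TFE_Py_TensorShapeOnDevice is invoked with only the tensors
  // argument, so the value is currently unused here.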
1205   m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors,
1206                                          int slice_dim) {
1207     return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr()));
1208   });
1209   m.def("TFE_Py_EnableInteractivePythonLogging",
1210         &TFE_Py_EnableInteractivePythonLogging);
1211 
1212   // Additional Context Logic
1213   m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
1214     return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr()));
1215   });
1216   m.def("TFE_Py_SetCEagerContext", [](const py::handle& ctx) {
1217     // TODO(mdan): This cast might need rewriting to ImmediateExecutionContext.
1218     tensorflow::SetCEagerContext(reinterpret_cast<tensorflow::EagerContext*>(
1219         tensorflow::InputTFE_Context(ctx)));
1220   });
1221   m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
1222     return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
1223   });
1224   m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
1225     tensorflow::Safe_TF_StatusPtr status =
1226         tensorflow::make_safe(TF_NewStatus());
1227     tensorflow::Safe_TF_BufferPtr buf =
1228         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1229     TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data,
1230                             buf.get()->length, status.get());
1231     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1232   });
1233   m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code,
1234                                      const char* message) {
1235     tensorflow::Safe_TF_StatusPtr status =
1236         tensorflow::make_safe(TF_NewStatus());
1237     TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
1238     TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
1239   });
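  // The integer `code` above is interpreted as a canonical TF_Code value
  // (e.g. TF_CANCELLED == 1, TF_UNAVAILABLE == 14). An equivalent direct use
  // of the C API looks roughly like this (illustrative sketch; `ctx_c` stands
  // in for an existing TFE_Context*):
  //   TF_Status* s = TF_NewStatus();
  //   TF_SetStatus(s, TF_UNAVAILABLE, "peer task failed");
  //   TFE_AbortCollectiveOps(ctx_c, s);
  //   TF_DeleteStatus(s);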
1240   m.def("TFE_CollectiveOpsCheckPeerHealth",
1241         [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
1242           tensorflow::Safe_TF_StatusPtr status =
1243               tensorflow::make_safe(TF_NewStatus());
1244           TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
1245                                            task, timeout_in_ms, status.get());
1246           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1247         });
1248   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
1249   m.def("TF_ListPluggablePhysicalDevices",
1250         &tensorflow::TF_ListPluggablePhysicalDevices);
1251   m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
1252   m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
1253         py::return_value_policy::reference);
1254   m.def("TF_DeviceListCount", &TF_DeviceListCount);
1255   m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
1256     tensorflow::Safe_TF_StatusPtr status =
1257         tensorflow::make_safe(TF_NewStatus());
1258     auto output = TF_DeviceListName(list, index, status.get());
1259     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1260     return output;
1261   });
1262   m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
1263     tensorflow::Safe_TF_StatusPtr status =
1264         tensorflow::make_safe(TF_NewStatus());
1265     auto output = TF_DeviceListType(list, index, status.get());
1266     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1267     return output;
1268   });
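  // Illustrative sketch (not executed here) of iterating a TF_DeviceList with
  // the functions bound above; `ctx_c` stands in for an existing TFE_Context*
  // and status checks are omitted for brevity.
  //   TF_Status* s = TF_NewStatus();
  //   TF_DeviceList* devices = TFE_ContextListDevices(ctx_c, s);
  //   for (int i = 0; i < TF_DeviceListCount(devices); ++i) {
  //     const char* name = TF_DeviceListName(devices, i, s);
  //     const char* type = TF_DeviceListType(devices, i, s);  // e.g. "CPU"
  //     ...
  //   }
  //   TF_DeleteDeviceList(devices);
  //   TF_DeleteStatus(s);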
1269 
1270   m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
1271 
1272   // TFE_MonitoringCounter Logic
1273   m.def("TFE_MonitoringCounterCellIncrementBy",
1274         &TFE_MonitoringCounterCellIncrementBy);
1275   m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
1276   m.def(
1277       "TFE_MonitoringNewCounter0",
1278       [](const char* name, const char* description) {
1279         tensorflow::Safe_TF_StatusPtr status =
1280             tensorflow::make_safe(TF_NewStatus());
1281         auto output =
1282             TFE_MonitoringNewCounter0(name, status.get(), description);
1283         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1284         return output;
1285       },
1286       py::return_value_policy::reference);
1287   m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
1288         py::return_value_policy::reference);
1289   m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
1290         py::return_value_policy::reference);
1291   m.def(
1292       "TFE_MonitoringNewCounter1",
1293       [](const char* name, const char* description, const char* label1) {
1294         tensorflow::Safe_TF_StatusPtr status =
1295             tensorflow::make_safe(TF_NewStatus());
1296         auto output =
1297             TFE_MonitoringNewCounter1(name, status.get(), description, label1);
1298         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1299         return output;
1300       },
1301       py::return_value_policy::reference);
1302   m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
1303         py::return_value_policy::reference);
1304   m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
1305         py::return_value_policy::reference);
1306   m.def(
1307       "TFE_MonitoringNewCounter2",
1308       [](const char* name, const char* description, const char* label1,
1309          const char* label2) {
1310         tensorflow::Safe_TF_StatusPtr status =
1311             tensorflow::make_safe(TF_NewStatus());
1312         auto output = TFE_MonitoringNewCounter2(name, status.get(), description,
1313                                                 label1, label2);
1314         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1315         return output;
1316       },
1317       py::return_value_policy::reference);
1318   m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
1319         py::return_value_policy::reference);
1320   m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
1321         py::return_value_policy::reference);
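  // Illustrative sketch (not executed here) of the counter C API wrapped
  // above, using the single-label variant; metric and label names are
  // placeholders.
  //   TF_Status* s = TF_NewStatus();
  //   TFE_MonitoringCounter1* counter = TFE_MonitoringNewCounter1(
  //       "/example/ops_run", s, "Ops run by this example.", "op_name");
  //   TFE_MonitoringCounterCell* cell =
  //       TFE_MonitoringGetCellCounter1(counter, "MatMul");
  //   TFE_MonitoringCounterCellIncrementBy(cell, 1);
  //   int64_t value = TFE_MonitoringCounterCellValue(cell);
  //   TFE_MonitoringDeleteCounter1(counter);
  //   TF_DeleteStatus(s);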
1322 
1323   // TFE_MonitoringIntGauge Logic
1324   m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
1325   m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
1326   m.def(
1327       "TFE_MonitoringNewIntGauge0",
1328       [](const char* name, const char* description) {
1329         tensorflow::Safe_TF_StatusPtr status =
1330             tensorflow::make_safe(TF_NewStatus());
1331         auto output =
1332             TFE_MonitoringNewIntGauge0(name, status.get(), description);
1333         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1334         return output;
1335       },
1336       py::return_value_policy::reference);
1337   m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
1338         py::return_value_policy::reference);
1339   m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
1340         py::return_value_policy::reference);
1341   m.def(
1342       "TFE_MonitoringNewIntGauge1",
1343       [](const char* name, const char* description, const char* label1) {
1344         tensorflow::Safe_TF_StatusPtr status =
1345             tensorflow::make_safe(TF_NewStatus());
1346         auto output =
1347             TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
1348         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1349         return output;
1350       },
1351       py::return_value_policy::reference);
1352   m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
1353         py::return_value_policy::reference);
1354   m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
1355         py::return_value_policy::reference);
1356   m.def(
1357       "TFE_MonitoringNewIntGauge2",
1358       [](const char* name, const char* description, const char* label1,
1359          const char* label2) {
1360         tensorflow::Safe_TF_StatusPtr status =
1361             tensorflow::make_safe(TF_NewStatus());
1362         auto output = TFE_MonitoringNewIntGauge2(name, status.get(),
1363                                                  description, label1, label2);
1364         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1365         return output;
1366       },
1367       py::return_value_policy::reference);
1368   m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
1369         py::return_value_policy::reference);
1370   m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
1371         py::return_value_policy::reference);
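  // The IntGauge bindings above (and the StringGauge and BoolGauge bindings
  // below) follow the same new / get-cell / delete pattern as the counters,
  // with the difference that gauge cells are overwritten via a *CellSet call
  // and read back via a *CellValue call rather than being incremented.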

  // TFE_MonitoringStringGauge Logic
1372   m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
1373   m.def("TFE_MonitoringStringGaugeCellValue",
1374         &TFE_MonitoringStringGaugeCellValue);
1375   m.def(
1376       "TFE_MonitoringNewStringGauge0",
1377       [](const char* name, const char* description) {
1378         tensorflow::Safe_TF_StatusPtr status =
1379             tensorflow::make_safe(TF_NewStatus());
1380         auto output =
1381             TFE_MonitoringNewStringGauge0(name, status.get(), description);
1382         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1383         return output;
1384       },
1385       py::return_value_policy::reference);
1388   m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
1389   m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
1390         py::return_value_policy::reference);
1391   m.def(
1392       "TFE_MonitoringNewStringGauge1",
1393       [](const char* name, const char* description, const char* label1) {
1394         tensorflow::Safe_TF_StatusPtr status =
1395             tensorflow::make_safe(TF_NewStatus());
1396         auto output = TFE_MonitoringNewStringGauge1(name, status.get(),
1397                                                     description, label1);
1398         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1399         return output;
1400       },
1401       py::return_value_policy::reference);
1402   m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
1403   m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
1404         py::return_value_policy::reference);
1405   m.def(
1406       "TFE_MonitoringNewStringGauge2",
1407       [](const char* name, const char* description, const char* label1,
1408          const char* label2) {
1409         tensorflow::Safe_TF_StatusPtr status =
1410             tensorflow::make_safe(TF_NewStatus());
1411         auto output = TFE_MonitoringNewStringGauge2(
1412             name, status.get(), description, label1, label2);
1413         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1414         return output;
1415       },
1416       py::return_value_policy::reference);
1417   m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
1418   m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
1419         py::return_value_policy::reference);
1420 
1421   m.def(
1422       "TFE_MonitoringNewStringGauge3",
1423       [](const char* name, const char* description, const char* label1,
1424          const char* label2, const char* label3) {
1425         tensorflow::Safe_TF_StatusPtr status =
1426             tensorflow::make_safe(TF_NewStatus());
1427         auto output = TFE_MonitoringNewStringGauge3(
1428             name, status.get(), description, label1, label2, label3);
1429         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1430         return output;
1431       },
1432       py::return_value_policy::reference);
1433   m.def("TFE_MonitoringDeleteStringGauge3", &TFE_MonitoringDeleteStringGauge3);
1434   m.def("TFE_MonitoringGetCellStringGauge3", &TFE_MonitoringGetCellStringGauge3,
1435         py::return_value_policy::reference);
1436 
1437   m.def(
1438       "TFE_MonitoringNewStringGauge4",
1439       [](const char* name, const char* description, const char* label1,
1440          const char* label2, const char* label3, const char* label4) {
1441         tensorflow::Safe_TF_StatusPtr status =
1442             tensorflow::make_safe(TF_NewStatus());
1443         auto output = TFE_MonitoringNewStringGauge4(
1444             name, status.get(), description, label1, label2, label3, label4);
1445         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1446         return output;
1447       },
1448       py::return_value_policy::reference);
1449   m.def("TFE_MonitoringDeleteStringGauge4", &TFE_MonitoringDeleteStringGauge4);
1450   m.def("TFE_MonitoringGetCellStringGauge4", &TFE_MonitoringGetCellStringGauge4,
1451         py::return_value_policy::reference);
1452 
1453   // TFE_MonitoringBoolGauge Logic
1454   m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
1455   m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
1456   m.def(
1457       "TFE_MonitoringNewBoolGauge0",
1458       [](const char* name, const char* description) {
1459         tensorflow::Safe_TF_StatusPtr status =
1460             tensorflow::make_safe(TF_NewStatus());
1461         auto output =
1462             TFE_MonitoringNewBoolGauge0(name, status.get(), description);
1463         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1464         return output;
1465       },
1466       py::return_value_policy::reference);
1467   m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
1468         py::return_value_policy::reference);
1469   m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
1470         py::return_value_policy::reference);
1471   m.def(
1472       "TFE_MonitoringNewBoolGauge1",
1473       [](const char* name, const char* description, const char* label1) {
1474         tensorflow::Safe_TF_StatusPtr status =
1475             tensorflow::make_safe(TF_NewStatus());
1476         auto output = TFE_MonitoringNewBoolGauge1(name, status.get(),
1477                                                   description, label1);
1478         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1479         return output;
1480       },
1481       py::return_value_policy::reference);
1482   m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
1483         py::return_value_policy::reference);
1484   m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
1485         py::return_value_policy::reference);
1486   m.def(
1487       "TFE_MonitoringNewBoolGauge2",
1488       [](const char* name, const char* description, const char* label1,
1489          const char* label2) {
1490         tensorflow::Safe_TF_StatusPtr status =
1491             tensorflow::make_safe(TF_NewStatus());
1492         auto output = TFE_MonitoringNewBoolGauge2(name, status.get(),
1493                                                   description, label1, label2);
1494         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1495         return output;
1496       },
1497       py::return_value_policy::reference);
1498   m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
1499         py::return_value_policy::reference);
1500   m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
1501         py::return_value_policy::reference);
1502 
1503   // TFE_MonitoringSampler Logic
1504   m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
1505   m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
1506   m.def("TFE_MonitoringNewExponentialBuckets",
1507         &TFE_MonitoringNewExponentialBuckets,
1508         py::return_value_policy::reference);
1509   m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
1510         py::return_value_policy::reference);
1511   m.def(
1512       "TFE_MonitoringNewSampler0",
1513       [](const char* name, TFE_MonitoringBuckets* buckets,
1514          const char* description) {
1515         tensorflow::Safe_TF_StatusPtr status =
1516             tensorflow::make_safe(TF_NewStatus());
1517         auto output =
1518             TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
1519         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1520         return output;
1521       },
1522       py::return_value_policy::reference);
1523   m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
1524         py::return_value_policy::reference);
1525   m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
1526         py::return_value_policy::reference);
1527   m.def(
1528       "TFE_MonitoringNewSampler1",
1529       [](const char* name, TFE_MonitoringBuckets* buckets,
1530          const char* description, const char* label1) {
1531         tensorflow::Safe_TF_StatusPtr status =
1532             tensorflow::make_safe(TF_NewStatus());
1533         auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(),
1534                                                 description, label1);
1535         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1536         return output;
1537       },
1538       py::return_value_policy::reference);
1539   m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
1540         py::return_value_policy::reference);
1541   m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
1542         py::return_value_policy::reference);
1543   m.def(
1544       "TFE_MonitoringNewSampler2",
1545       [](const char* name, TFE_MonitoringBuckets* buckets,
1546          const char* description, const char* label1, const char* label2) {
1547         tensorflow::Safe_TF_StatusPtr status =
1548             tensorflow::make_safe(TF_NewStatus());
1549         auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(),
1550                                                 description, label1, label2);
1551         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1552         return output;
1553       },
1554       py::return_value_policy::reference);
1555   m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
1556         py::return_value_policy::reference);
1557   m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
1558         py::return_value_policy::reference);
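  // Illustrative sketch (not executed here) of the sampler C API wrapped
  // above; bucket parameters and metric names are placeholders.
  //   TF_Status* s = TF_NewStatus();
  //   TFE_MonitoringBuckets* buckets = TFE_MonitoringNewExponentialBuckets(
  //       /*scale=*/1.0, /*growth_factor=*/2.0, /*bucket_count=*/10);
  //   TFE_MonitoringSampler0* sampler = TFE_MonitoringNewSampler0(
  //       "/example/latency", buckets, s, "Example latency distribution.");
  //   TFE_MonitoringSamplerCell* cell = TFE_MonitoringGetCellSampler0(sampler);
  //   TFE_MonitoringSamplerCellAdd(cell, 3.5);
  //   TFE_MonitoringDeleteSampler0(sampler);
  //   TFE_MonitoringDeleteBuckets(buckets);
  //   TF_DeleteStatus(s);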
1559 
1560   // TFE_CancellationManager Logic
1561   m.def("TFE_NewCancellationManager",
1562         []() { return new tensorflow::CancellationManager(); });
1563   m.def("TFE_CancellationManagerIsCancelled",
1564         &tensorflow::CancellationManager::IsCancelled);
1565   m.def("TFE_CancellationManagerStartCancel",
1566         &tensorflow::CancellationManager::StartCancel);
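  // A minimal C++-side usage sketch of tensorflow::CancellationManager,
  // assuming the standard core API (illustrative only, not executed here):
  //   tensorflow::CancellationManager cm;
  //   auto token = cm.get_cancellation_token();
  //   cm.RegisterCallback(token, [] { /* abort pending work */ });
  //   ...
  //   cm.StartCancel();              // runs registered callbacks
  //   bool cancelled = cm.IsCancelled();  // true after StartCancel()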
1567 
1568   m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);
1569 
1570   // Util buffer helper functions
1571   m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
1572         py::return_value_policy::reference);
1573 
1574   // DLPack functions
1575   m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
1576     PyObject* eager_tensor_pyobject_ptr = o.ptr();
1577     tensorflow::Safe_TF_StatusPtr status =
1578         tensorflow::make_safe(TF_NewStatus());
1579 
1580     if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
1581       status->status = tensorflow::errors::InvalidArgument(
1582           "The argument to `to_dlpack` must be a TF tensor, not a Python object");
1583       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1584     }
1585 
1586     TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
1587     void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
1588     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1589 
1590     py::capsule capsule(
1591         dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
1592           if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
1593             void* dlm_rptr =
1594                 PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
1595             if (dlm_rptr) {
1596               tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
1597               PyCapsule_SetDestructor(capsule, nullptr);
1598             }
1599           }
1600         });
1601     return capsule;
1602   });
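  // DLPack hand-off convention: the capsule produced above is named
  // "dltensor"; a consumer (e.g. another framework's from_dlpack) is expected
  // to rename it to "used_dltensor" and take over calling the DLManagedTensor
  // deleter, which is why the capsule destructor above only fires when the
  // capsule was never consumed.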
1603 
1604   m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
1605                                     const py::handle& context) {
1606     tensorflow::Safe_TF_StatusPtr status =
1607         tensorflow::make_safe(TF_NewStatus());
1608     if (absl::string_view(pycapsule.name()) !=
1609         tensorflow::kDlTensorCapsuleName) {
1610       status->status = tensorflow::errors::InvalidArgument(
1611           "DLPack tensor must be a capsule with name \"dltensor\", got \"",
1612           absl::string_view(pycapsule.name()),
1613           "\". Note that a DLPack tensor may be consumed at most once.");
1614       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1615     }
1616 
1617     TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
1618         pycapsule, status.get(), tensorflow::InputTFE_Context(context));
1619 
1620     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1621 
1622     PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
1623     PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
1624 
1625     PyObject* pyhandle = EagerTensorFromHandle(thandle);
1626     return tensorflow::PyoOrThrow(pyhandle);
1627   });
1628 
1629   m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
1630                                           const py::capsule& device,
1631                                           const char* device_name,
1632                                           const py::capsule& device_info) {
1633     tensorflow::Safe_TF_StatusPtr status =
1634         tensorflow::make_safe(TF_NewStatus());
1635     if (absl::string_view(device.name()) != "TFE_CustomDevice") {
1636       status->status = tensorflow::errors::InvalidArgument(
1637           "Expected a capsule named 'TFE_CustomDevice' for the `device` "
1638           "argument, got ",
1639           absl::string_view(device.name()));
1640       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1641     }
1642     if (absl::string_view(device_info.name()) !=
1643         "TFE_CustomDevice_DeviceInfo") {
1644       status->status = tensorflow::errors::InvalidArgument(
1645           "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
1646           "the `device_info` argument, got ",
1647           absl::string_view(device_info.name()));
1648       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1649     }
1650     // TFE_RegisterCustomDevice takes ownership of device_info, so clear the
    // capsule's destructor before handing the pointer over.
1651     PyCapsule_SetDestructor(device_info.ptr(), nullptr);
1652     TFE_RegisterCustomDevice(
1653         tensorflow::InputTFE_Context(context),
1654         *reinterpret_cast<TFE_CustomDevice*>(
1655             PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
1656         device_name,
1657         PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
1658         status.get());
1659     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1660   });
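  // The custom-device registration above expects two PyCapsules produced by a
  // device plugin: one named "TFE_CustomDevice" wrapping a TFE_CustomDevice
  // struct of callbacks (declared in the eager C API), and one named
  // "TFE_CustomDevice_DeviceInfo" wrapping the plugin's opaque state. A plugin
  // might package them roughly like this (illustrative sketch; `my_device` and
  // `my_device_info` are placeholders):
  //   py::capsule device(&my_device, "TFE_CustomDevice");
  //   py::capsule device_info(my_device_info, "TFE_CustomDevice_DeviceInfo");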
1661 
1662   py::class_<EagerContextThreadLocalDataWrapper>(m,
1663                                                  "EagerContextThreadLocalData")
1664       .def(py::init<py::handle, py::handle, py::handle>(),
1665            py::arg("py_eager_context"), py::arg("is_eager"),
1666            py::arg("device_spec"))
1667       .def_property("is_eager",
1668                     &EagerContextThreadLocalDataWrapper::get_is_eager,
1669                     &EagerContextThreadLocalDataWrapper::set_is_eager)
1670       .def_property(
1671           "invoking_op_callbacks",
1672           &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
1673           &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
1674       .def_property("device_name",
1675                     &EagerContextThreadLocalDataWrapper::get_device_name,
1676                     &EagerContextThreadLocalDataWrapper::set_device_name)
1677       .def_property("scope_name",
1678                     &EagerContextThreadLocalDataWrapper::get_scope_name,
1679                     &EagerContextThreadLocalDataWrapper::set_scope_name)
1680       .def_property("device_spec",
1681                     &EagerContextThreadLocalDataWrapper::get_device_spec,
1682                     &EagerContextThreadLocalDataWrapper::set_device_spec)
1683       .def_property(
1684           "function_call_options",
1685           &EagerContextThreadLocalDataWrapper::get_function_call_options,
1686           &EagerContextThreadLocalDataWrapper::set_function_call_options)
1687       .def_property("executor",
1688                     &EagerContextThreadLocalDataWrapper::get_executor,
1689                     &EagerContextThreadLocalDataWrapper::set_executor)
1690       .def_property("op_callbacks",
1691                     &EagerContextThreadLocalDataWrapper::get_op_callbacks,
1692                     &EagerContextThreadLocalDataWrapper::set_op_callbacks);
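  // Each def_property above pairs a getter with a setter, so Python code can
  // read and assign the attribute directly (e.g. a hypothetical
  // `data.scope_name = "outer"` routes through set_scope_name). A minimal,
  // self-contained sketch of the same pybind11 pattern with a hypothetical
  // struct (illustrative only):
  //   struct Point {
  //     int x = 0;
  //     int get_x() const { return x; }
  //     void set_x(int v) { x = v; }
  //   };
  //   py::class_<Point>(m, "Point")
  //       .def(py::init<>())
  //       .def_property("x", &Point::get_x, &Point::set_x);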
1693 
1694   // C API Enum
1695 
1696   py::enum_<TFE_ContextDevicePlacementPolicy>(
1697       m, "TFE_ContextDevicePlacementPolicy")
1698       .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
1699       .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
1700       .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
1701       .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
1702              TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
1703       .export_values();
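  // py::enum_ with .export_values() exposes the enumerators both as attributes
  // of the enum type and at module scope, mirroring how a C enum's constants
  // live in the enclosing namespace; Python callers can therefore use either
  // TFE_ContextDevicePlacementPolicy.TFE_DEVICE_PLACEMENT_SILENT or the
  // module-level TFE_DEVICE_PLACEMENT_SILENT.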
1704 
1705   py::enum_<TF_AttrType>(m, "TF_AttrType")
1706       .value("TF_ATTR_STRING", TF_ATTR_STRING)
1707       .value("TF_ATTR_INT", TF_ATTR_INT)
1708       .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
1709       .value("TF_ATTR_BOOL", TF_ATTR_BOOL)
1710       .value("TF_ATTR_TYPE", TF_ATTR_TYPE)
1711       .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
1712       .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
1713       .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
1714       .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
1715       .export_values();
1716 };
1717